@@ -17,11 +17,11 @@ |
||
| 17 | 17 | */ |
| 18 | 18 | interface ITriggerableProvider extends IProvider { |
| 19 | 19 | |
| 20 | - /** |
|
| 21 | - * Called when new tasks for this provider are coming in and there are currently |
|
| 22 | - * no tasks running for this provider's task type |
|
| 23 | - * |
|
| 24 | - * @since 33.0.0 |
|
| 25 | - */ |
|
| 26 | - public function trigger(): void; |
|
| 20 | + /** |
|
| 21 | + * Called when new tasks for this provider are coming in and there are currently |
|
| 22 | + * no tasks running for this provider's task type |
|
| 23 | + * |
|
| 24 | + * @since 33.0.0 |
|
| 25 | + */ |
|
| 26 | + public function trigger(): void; |
|
| 27 | 27 | } |
@@ -22,262 +22,262 @@ |
||
| 22 | 22 | * @extends QBMapper<Task> |
| 23 | 23 | */ |
| 24 | 24 | class TaskMapper extends QBMapper { |
| 25 | - public function __construct( |
|
| 26 | - IDBConnection $db, |
|
| 27 | - private ITimeFactory $timeFactory, |
|
| 28 | - ) { |
|
| 29 | - parent::__construct($db, 'taskprocessing_tasks', Task::class); |
|
| 30 | - } |
|
| 25 | + public function __construct( |
|
| 26 | + IDBConnection $db, |
|
| 27 | + private ITimeFactory $timeFactory, |
|
| 28 | + ) { |
|
| 29 | + parent::__construct($db, 'taskprocessing_tasks', Task::class); |
|
| 30 | + } |
|
| 31 | 31 | |
| 32 | - /** |
|
| 33 | - * @param int $id |
|
| 34 | - * @return Task |
|
| 35 | - * @throws Exception |
|
| 36 | - * @throws DoesNotExistException |
|
| 37 | - * @throws MultipleObjectsReturnedException |
|
| 38 | - */ |
|
| 39 | - public function find(int $id): Task { |
|
| 40 | - $qb = $this->db->getQueryBuilder(); |
|
| 41 | - $qb->select(Task::$columns) |
|
| 42 | - ->from($this->tableName) |
|
| 43 | - ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id))); |
|
| 44 | - return $this->findEntity($qb); |
|
| 45 | - } |
|
| 32 | + /** |
|
| 33 | + * @param int $id |
|
| 34 | + * @return Task |
|
| 35 | + * @throws Exception |
|
| 36 | + * @throws DoesNotExistException |
|
| 37 | + * @throws MultipleObjectsReturnedException |
|
| 38 | + */ |
|
| 39 | + public function find(int $id): Task { |
|
| 40 | + $qb = $this->db->getQueryBuilder(); |
|
| 41 | + $qb->select(Task::$columns) |
|
| 42 | + ->from($this->tableName) |
|
| 43 | + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id))); |
|
| 44 | + return $this->findEntity($qb); |
|
| 45 | + } |
|
| 46 | 46 | |
| 47 | - /** |
|
| 48 | - * @param list<string> $taskTypes |
|
| 49 | - * @param list<int> $taskIdsToIgnore |
|
| 50 | - * @return Task |
|
| 51 | - * @throws DoesNotExistException |
|
| 52 | - * @throws Exception |
|
| 53 | - */ |
|
| 54 | - public function findOldestScheduledByType(array $taskTypes, array $taskIdsToIgnore): Task { |
|
| 55 | - $qb = $this->db->getQueryBuilder(); |
|
| 56 | - $qb->select(Task::$columns) |
|
| 57 | - ->from($this->tableName) |
|
| 58 | - ->where($qb->expr()->eq('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_SCHEDULED, IQueryBuilder::PARAM_INT))) |
|
| 59 | - ->setMaxResults(1) |
|
| 60 | - ->orderBy('last_updated', 'ASC'); |
|
| 47 | + /** |
|
| 48 | + * @param list<string> $taskTypes |
|
| 49 | + * @param list<int> $taskIdsToIgnore |
|
| 50 | + * @return Task |
|
| 51 | + * @throws DoesNotExistException |
|
| 52 | + * @throws Exception |
|
| 53 | + */ |
|
| 54 | + public function findOldestScheduledByType(array $taskTypes, array $taskIdsToIgnore): Task { |
|
| 55 | + $qb = $this->db->getQueryBuilder(); |
|
| 56 | + $qb->select(Task::$columns) |
|
| 57 | + ->from($this->tableName) |
|
| 58 | + ->where($qb->expr()->eq('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_SCHEDULED, IQueryBuilder::PARAM_INT))) |
|
| 59 | + ->setMaxResults(1) |
|
| 60 | + ->orderBy('last_updated', 'ASC'); |
|
| 61 | 61 | |
| 62 | - if (!empty($taskTypes)) { |
|
| 63 | - $filter = []; |
|
| 64 | - foreach ($taskTypes as $taskType) { |
|
| 65 | - $filter[] = $qb->expr()->eq('type', $qb->createPositionalParameter($taskType)); |
|
| 66 | - } |
|
| 62 | + if (!empty($taskTypes)) { |
|
| 63 | + $filter = []; |
|
| 64 | + foreach ($taskTypes as $taskType) { |
|
| 65 | + $filter[] = $qb->expr()->eq('type', $qb->createPositionalParameter($taskType)); |
|
| 66 | + } |
|
| 67 | 67 | |
| 68 | - $qb->andWhere($qb->expr()->orX(...$filter)); |
|
| 69 | - } |
|
| 68 | + $qb->andWhere($qb->expr()->orX(...$filter)); |
|
| 69 | + } |
|
| 70 | 70 | |
| 71 | - if (!empty($taskIdsToIgnore)) { |
|
| 72 | - $qb->andWhere($qb->expr()->notIn('id', $qb->createNamedParameter($taskIdsToIgnore, IQueryBuilder::PARAM_INT_ARRAY))); |
|
| 73 | - } |
|
| 71 | + if (!empty($taskIdsToIgnore)) { |
|
| 72 | + $qb->andWhere($qb->expr()->notIn('id', $qb->createNamedParameter($taskIdsToIgnore, IQueryBuilder::PARAM_INT_ARRAY))); |
|
| 73 | + } |
|
| 74 | 74 | |
| 75 | - return $this->findEntity($qb); |
|
| 76 | - } |
|
| 75 | + return $this->findEntity($qb); |
|
| 76 | + } |
|
| 77 | 77 | |
| 78 | - /** |
|
| 79 | - * @param int $id |
|
| 80 | - * @param string|null $userId |
|
| 81 | - * @return Task |
|
| 82 | - * @throws DoesNotExistException |
|
| 83 | - * @throws Exception |
|
| 84 | - * @throws MultipleObjectsReturnedException |
|
| 85 | - */ |
|
| 86 | - public function findByIdAndUser(int $id, ?string $userId): Task { |
|
| 87 | - $qb = $this->db->getQueryBuilder(); |
|
| 88 | - $qb->select(Task::$columns) |
|
| 89 | - ->from($this->tableName) |
|
| 90 | - ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id))); |
|
| 91 | - if ($userId === null) { |
|
| 92 | - $qb->andWhere($qb->expr()->isNull('user_id')); |
|
| 93 | - } else { |
|
| 94 | - $qb->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 95 | - } |
|
| 96 | - return $this->findEntity($qb); |
|
| 97 | - } |
|
| 78 | + /** |
|
| 79 | + * @param int $id |
|
| 80 | + * @param string|null $userId |
|
| 81 | + * @return Task |
|
| 82 | + * @throws DoesNotExistException |
|
| 83 | + * @throws Exception |
|
| 84 | + * @throws MultipleObjectsReturnedException |
|
| 85 | + */ |
|
| 86 | + public function findByIdAndUser(int $id, ?string $userId): Task { |
|
| 87 | + $qb = $this->db->getQueryBuilder(); |
|
| 88 | + $qb->select(Task::$columns) |
|
| 89 | + ->from($this->tableName) |
|
| 90 | + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($id))); |
|
| 91 | + if ($userId === null) { |
|
| 92 | + $qb->andWhere($qb->expr()->isNull('user_id')); |
|
| 93 | + } else { |
|
| 94 | + $qb->andWhere($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 95 | + } |
|
| 96 | + return $this->findEntity($qb); |
|
| 97 | + } |
|
| 98 | 98 | |
| 99 | - /** |
|
| 100 | - * @param string|null $userId |
|
| 101 | - * @param string|null $taskType |
|
| 102 | - * @param string|null $customId |
|
| 103 | - * @return list<Task> |
|
| 104 | - * @throws Exception |
|
| 105 | - */ |
|
| 106 | - public function findByUserAndTaskType(?string $userId, ?string $taskType = null, ?string $customId = null): array { |
|
| 107 | - $qb = $this->db->getQueryBuilder(); |
|
| 108 | - $qb->select(Task::$columns) |
|
| 109 | - ->from($this->tableName) |
|
| 110 | - ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 111 | - if ($taskType !== null) { |
|
| 112 | - $qb->andWhere($qb->expr()->eq('type', $qb->createPositionalParameter($taskType))); |
|
| 113 | - } |
|
| 114 | - if ($customId !== null) { |
|
| 115 | - $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 116 | - } |
|
| 117 | - return $this->findEntities($qb); |
|
| 118 | - } |
|
| 99 | + /** |
|
| 100 | + * @param string|null $userId |
|
| 101 | + * @param string|null $taskType |
|
| 102 | + * @param string|null $customId |
|
| 103 | + * @return list<Task> |
|
| 104 | + * @throws Exception |
|
| 105 | + */ |
|
| 106 | + public function findByUserAndTaskType(?string $userId, ?string $taskType = null, ?string $customId = null): array { |
|
| 107 | + $qb = $this->db->getQueryBuilder(); |
|
| 108 | + $qb->select(Task::$columns) |
|
| 109 | + ->from($this->tableName) |
|
| 110 | + ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 111 | + if ($taskType !== null) { |
|
| 112 | + $qb->andWhere($qb->expr()->eq('type', $qb->createPositionalParameter($taskType))); |
|
| 113 | + } |
|
| 114 | + if ($customId !== null) { |
|
| 115 | + $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 116 | + } |
|
| 117 | + return $this->findEntities($qb); |
|
| 118 | + } |
|
| 119 | 119 | |
| 120 | - /** |
|
| 121 | - * @param string $userId |
|
| 122 | - * @param string $appId |
|
| 123 | - * @param string|null $customId |
|
| 124 | - * @return list<Task> |
|
| 125 | - * @throws Exception |
|
| 126 | - */ |
|
| 127 | - public function findUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array { |
|
| 128 | - $qb = $this->db->getQueryBuilder(); |
|
| 129 | - $qb->select(Task::$columns) |
|
| 130 | - ->from($this->tableName) |
|
| 131 | - ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))) |
|
| 132 | - ->andWhere($qb->expr()->eq('app_id', $qb->createPositionalParameter($appId))); |
|
| 133 | - if ($customId !== null) { |
|
| 134 | - $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 135 | - } |
|
| 136 | - return $this->findEntities($qb); |
|
| 137 | - } |
|
| 120 | + /** |
|
| 121 | + * @param string $userId |
|
| 122 | + * @param string $appId |
|
| 123 | + * @param string|null $customId |
|
| 124 | + * @return list<Task> |
|
| 125 | + * @throws Exception |
|
| 126 | + */ |
|
| 127 | + public function findUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array { |
|
| 128 | + $qb = $this->db->getQueryBuilder(); |
|
| 129 | + $qb->select(Task::$columns) |
|
| 130 | + ->from($this->tableName) |
|
| 131 | + ->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))) |
|
| 132 | + ->andWhere($qb->expr()->eq('app_id', $qb->createPositionalParameter($appId))); |
|
| 133 | + if ($customId !== null) { |
|
| 134 | + $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 135 | + } |
|
| 136 | + return $this->findEntities($qb); |
|
| 137 | + } |
|
| 138 | 138 | |
| 139 | - /** |
|
| 140 | - * @param string|null $userId |
|
| 141 | - * @param string|null $taskType |
|
| 142 | - * @param string|null $appId |
|
| 143 | - * @param string|null $customId |
|
| 144 | - * @param int|null $status |
|
| 145 | - * @param int|null $scheduleAfter |
|
| 146 | - * @param int|null $endedBefore |
|
| 147 | - * @return list<Task> |
|
| 148 | - * @throws Exception |
|
| 149 | - */ |
|
| 150 | - public function findTasks( |
|
| 151 | - ?string $userId, ?string $taskType = null, ?string $appId = null, ?string $customId = null, |
|
| 152 | - ?int $status = null, ?int $scheduleAfter = null, ?int $endedBefore = null): array { |
|
| 153 | - $qb = $this->db->getQueryBuilder(); |
|
| 154 | - $qb->select(Task::$columns) |
|
| 155 | - ->from($this->tableName); |
|
| 139 | + /** |
|
| 140 | + * @param string|null $userId |
|
| 141 | + * @param string|null $taskType |
|
| 142 | + * @param string|null $appId |
|
| 143 | + * @param string|null $customId |
|
| 144 | + * @param int|null $status |
|
| 145 | + * @param int|null $scheduleAfter |
|
| 146 | + * @param int|null $endedBefore |
|
| 147 | + * @return list<Task> |
|
| 148 | + * @throws Exception |
|
| 149 | + */ |
|
| 150 | + public function findTasks( |
|
| 151 | + ?string $userId, ?string $taskType = null, ?string $appId = null, ?string $customId = null, |
|
| 152 | + ?int $status = null, ?int $scheduleAfter = null, ?int $endedBefore = null): array { |
|
| 153 | + $qb = $this->db->getQueryBuilder(); |
|
| 154 | + $qb->select(Task::$columns) |
|
| 155 | + ->from($this->tableName); |
|
| 156 | 156 | |
| 157 | - // empty string: no userId filter |
|
| 158 | - if ($userId !== '') { |
|
| 159 | - $qb->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 160 | - } |
|
| 161 | - if ($taskType !== null) { |
|
| 162 | - $qb->andWhere($qb->expr()->eq('type', $qb->createPositionalParameter($taskType))); |
|
| 163 | - } |
|
| 164 | - if ($appId !== null) { |
|
| 165 | - $qb->andWhere($qb->expr()->eq('app_id', $qb->createPositionalParameter($appId))); |
|
| 166 | - } |
|
| 167 | - if ($customId !== null) { |
|
| 168 | - $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 169 | - } |
|
| 170 | - if ($status !== null) { |
|
| 171 | - $qb->andWhere($qb->expr()->eq('status', $qb->createPositionalParameter($status, IQueryBuilder::PARAM_INT))); |
|
| 172 | - } |
|
| 173 | - if ($scheduleAfter !== null) { |
|
| 174 | - $qb->andWhere($qb->expr()->isNotNull('scheduled_at')); |
|
| 175 | - $qb->andWhere($qb->expr()->gt('scheduled_at', $qb->createPositionalParameter($scheduleAfter, IQueryBuilder::PARAM_INT))); |
|
| 176 | - } |
|
| 177 | - if ($endedBefore !== null) { |
|
| 178 | - $qb->andWhere($qb->expr()->isNotNull('ended_at')); |
|
| 179 | - $qb->andWhere($qb->expr()->lt('ended_at', $qb->createPositionalParameter($endedBefore, IQueryBuilder::PARAM_INT))); |
|
| 180 | - } |
|
| 181 | - return $this->findEntities($qb); |
|
| 182 | - } |
|
| 157 | + // empty string: no userId filter |
|
| 158 | + if ($userId !== '') { |
|
| 159 | + $qb->where($qb->expr()->eq('user_id', $qb->createPositionalParameter($userId))); |
|
| 160 | + } |
|
| 161 | + if ($taskType !== null) { |
|
| 162 | + $qb->andWhere($qb->expr()->eq('type', $qb->createPositionalParameter($taskType))); |
|
| 163 | + } |
|
| 164 | + if ($appId !== null) { |
|
| 165 | + $qb->andWhere($qb->expr()->eq('app_id', $qb->createPositionalParameter($appId))); |
|
| 166 | + } |
|
| 167 | + if ($customId !== null) { |
|
| 168 | + $qb->andWhere($qb->expr()->eq('custom_id', $qb->createPositionalParameter($customId))); |
|
| 169 | + } |
|
| 170 | + if ($status !== null) { |
|
| 171 | + $qb->andWhere($qb->expr()->eq('status', $qb->createPositionalParameter($status, IQueryBuilder::PARAM_INT))); |
|
| 172 | + } |
|
| 173 | + if ($scheduleAfter !== null) { |
|
| 174 | + $qb->andWhere($qb->expr()->isNotNull('scheduled_at')); |
|
| 175 | + $qb->andWhere($qb->expr()->gt('scheduled_at', $qb->createPositionalParameter($scheduleAfter, IQueryBuilder::PARAM_INT))); |
|
| 176 | + } |
|
| 177 | + if ($endedBefore !== null) { |
|
| 178 | + $qb->andWhere($qb->expr()->isNotNull('ended_at')); |
|
| 179 | + $qb->andWhere($qb->expr()->lt('ended_at', $qb->createPositionalParameter($endedBefore, IQueryBuilder::PARAM_INT))); |
|
| 180 | + } |
|
| 181 | + return $this->findEntities($qb); |
|
| 182 | + } |
|
| 183 | 183 | |
| 184 | - /** |
|
| 185 | - * @param int $timeout |
|
| 186 | - * @param bool $force If true, ignore the allow_cleanup flag |
|
| 187 | - * @return int the number of deleted tasks |
|
| 188 | - * @throws Exception |
|
| 189 | - */ |
|
| 190 | - public function deleteOlderThan(int $timeout, bool $force = false): int { |
|
| 191 | - $qb = $this->db->getQueryBuilder(); |
|
| 192 | - $qb->delete($this->tableName) |
|
| 193 | - ->where($qb->expr()->lt('last_updated', $qb->createPositionalParameter($this->timeFactory->getDateTime()->getTimestamp() - $timeout))); |
|
| 194 | - if (!$force) { |
|
| 195 | - $qb->andWhere($qb->expr()->eq('allow_cleanup', $qb->createPositionalParameter(1, IQueryBuilder::PARAM_INT))); |
|
| 196 | - } |
|
| 197 | - return $qb->executeStatement(); |
|
| 198 | - } |
|
| 184 | + /** |
|
| 185 | + * @param int $timeout |
|
| 186 | + * @param bool $force If true, ignore the allow_cleanup flag |
|
| 187 | + * @return int the number of deleted tasks |
|
| 188 | + * @throws Exception |
|
| 189 | + */ |
|
| 190 | + public function deleteOlderThan(int $timeout, bool $force = false): int { |
|
| 191 | + $qb = $this->db->getQueryBuilder(); |
|
| 192 | + $qb->delete($this->tableName) |
|
| 193 | + ->where($qb->expr()->lt('last_updated', $qb->createPositionalParameter($this->timeFactory->getDateTime()->getTimestamp() - $timeout))); |
|
| 194 | + if (!$force) { |
|
| 195 | + $qb->andWhere($qb->expr()->eq('allow_cleanup', $qb->createPositionalParameter(1, IQueryBuilder::PARAM_INT))); |
|
| 196 | + } |
|
| 197 | + return $qb->executeStatement(); |
|
| 198 | + } |
|
| 199 | 199 | |
| 200 | - /** |
|
| 201 | - * @param int $timeout |
|
| 202 | - * @param bool $force If true, ignore the allow_cleanup flag |
|
| 203 | - * @return \Generator<Task> |
|
| 204 | - * @throws Exception |
|
| 205 | - */ |
|
| 206 | - public function getTasksToCleanup(int $timeout, bool $force = false): \Generator { |
|
| 207 | - $qb = $this->db->getQueryBuilder(); |
|
| 208 | - $qb->select(Task::$columns) |
|
| 209 | - ->from($this->tableName) |
|
| 210 | - ->where($qb->expr()->lt('last_updated', $qb->createPositionalParameter($this->timeFactory->getDateTime()->getTimestamp() - $timeout))); |
|
| 211 | - if (!$force) { |
|
| 212 | - $qb->andWhere($qb->expr()->eq('allow_cleanup', $qb->createPositionalParameter(1, IQueryBuilder::PARAM_INT))); |
|
| 213 | - } |
|
| 214 | - foreach ($this->yieldEntities($qb) as $entity) { |
|
| 215 | - yield $entity; |
|
| 216 | - }; |
|
| 217 | - } |
|
| 200 | + /** |
|
| 201 | + * @param int $timeout |
|
| 202 | + * @param bool $force If true, ignore the allow_cleanup flag |
|
| 203 | + * @return \Generator<Task> |
|
| 204 | + * @throws Exception |
|
| 205 | + */ |
|
| 206 | + public function getTasksToCleanup(int $timeout, bool $force = false): \Generator { |
|
| 207 | + $qb = $this->db->getQueryBuilder(); |
|
| 208 | + $qb->select(Task::$columns) |
|
| 209 | + ->from($this->tableName) |
|
| 210 | + ->where($qb->expr()->lt('last_updated', $qb->createPositionalParameter($this->timeFactory->getDateTime()->getTimestamp() - $timeout))); |
|
| 211 | + if (!$force) { |
|
| 212 | + $qb->andWhere($qb->expr()->eq('allow_cleanup', $qb->createPositionalParameter(1, IQueryBuilder::PARAM_INT))); |
|
| 213 | + } |
|
| 214 | + foreach ($this->yieldEntities($qb) as $entity) { |
|
| 215 | + yield $entity; |
|
| 216 | + }; |
|
| 217 | + } |
|
| 218 | 218 | |
| 219 | - public function update(Entity $entity): Entity { |
|
| 220 | - $entity->setLastUpdated($this->timeFactory->now()->getTimestamp()); |
|
| 221 | - return parent::update($entity); |
|
| 222 | - } |
|
| 219 | + public function update(Entity $entity): Entity { |
|
| 220 | + $entity->setLastUpdated($this->timeFactory->now()->getTimestamp()); |
|
| 221 | + return parent::update($entity); |
|
| 222 | + } |
|
| 223 | 223 | |
| 224 | - public function lockTask(Entity $entity): int { |
|
| 225 | - $qb = $this->db->getQueryBuilder(); |
|
| 226 | - $qb->update($this->tableName) |
|
| 227 | - ->set('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_RUNNING, IQueryBuilder::PARAM_INT)) |
|
| 228 | - ->where($qb->expr()->eq('id', $qb->createPositionalParameter($entity->getId(), IQueryBuilder::PARAM_INT))) |
|
| 229 | - ->andWhere($qb->expr()->neq('status', $qb->createPositionalParameter(2, IQueryBuilder::PARAM_INT))); |
|
| 230 | - try { |
|
| 231 | - return $qb->executeStatement(); |
|
| 232 | - } catch (Exception) { |
|
| 233 | - return 0; |
|
| 234 | - } |
|
| 235 | - } |
|
| 224 | + public function lockTask(Entity $entity): int { |
|
| 225 | + $qb = $this->db->getQueryBuilder(); |
|
| 226 | + $qb->update($this->tableName) |
|
| 227 | + ->set('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_RUNNING, IQueryBuilder::PARAM_INT)) |
|
| 228 | + ->where($qb->expr()->eq('id', $qb->createPositionalParameter($entity->getId(), IQueryBuilder::PARAM_INT))) |
|
| 229 | + ->andWhere($qb->expr()->neq('status', $qb->createPositionalParameter(2, IQueryBuilder::PARAM_INT))); |
|
| 230 | + try { |
|
| 231 | + return $qb->executeStatement(); |
|
| 232 | + } catch (Exception) { |
|
| 233 | + return 0; |
|
| 234 | + } |
|
| 235 | + } |
|
| 236 | 236 | |
| 237 | - /** |
|
| 238 | - * @param list<string> $taskTypes |
|
| 239 | - * @param list<int> $taskIdsToIgnore |
|
| 240 | - * @param int $numberOfTasks |
|
| 241 | - * @return list<Task> |
|
| 242 | - * @throws Exception |
|
| 243 | - */ |
|
| 244 | - public function findNOldestScheduledByType(array $taskTypes, array $taskIdsToIgnore, int $numberOfTasks) { |
|
| 245 | - $qb = $this->db->getQueryBuilder(); |
|
| 246 | - $qb->select(Task::$columns) |
|
| 247 | - ->from($this->tableName) |
|
| 248 | - ->where($qb->expr()->eq('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_SCHEDULED, IQueryBuilder::PARAM_INT))) |
|
| 249 | - ->setMaxResults($numberOfTasks) |
|
| 250 | - ->orderBy('last_updated', 'ASC'); |
|
| 237 | + /** |
|
| 238 | + * @param list<string> $taskTypes |
|
| 239 | + * @param list<int> $taskIdsToIgnore |
|
| 240 | + * @param int $numberOfTasks |
|
| 241 | + * @return list<Task> |
|
| 242 | + * @throws Exception |
|
| 243 | + */ |
|
| 244 | + public function findNOldestScheduledByType(array $taskTypes, array $taskIdsToIgnore, int $numberOfTasks) { |
|
| 245 | + $qb = $this->db->getQueryBuilder(); |
|
| 246 | + $qb->select(Task::$columns) |
|
| 247 | + ->from($this->tableName) |
|
| 248 | + ->where($qb->expr()->eq('status', $qb->createPositionalParameter(\OCP\TaskProcessing\Task::STATUS_SCHEDULED, IQueryBuilder::PARAM_INT))) |
|
| 249 | + ->setMaxResults($numberOfTasks) |
|
| 250 | + ->orderBy('last_updated', 'ASC'); |
|
| 251 | 251 | |
| 252 | - if (!empty($taskTypes)) { |
|
| 253 | - $filter = []; |
|
| 254 | - foreach ($taskTypes as $taskType) { |
|
| 255 | - $filter[] = $qb->expr()->eq('type', $qb->createPositionalParameter($taskType)); |
|
| 256 | - } |
|
| 252 | + if (!empty($taskTypes)) { |
|
| 253 | + $filter = []; |
|
| 254 | + foreach ($taskTypes as $taskType) { |
|
| 255 | + $filter[] = $qb->expr()->eq('type', $qb->createPositionalParameter($taskType)); |
|
| 256 | + } |
|
| 257 | 257 | |
| 258 | - $qb->andWhere($qb->expr()->orX(...$filter)); |
|
| 259 | - } |
|
| 258 | + $qb->andWhere($qb->expr()->orX(...$filter)); |
|
| 259 | + } |
|
| 260 | 260 | |
| 261 | - if (!empty($taskIdsToIgnore)) { |
|
| 262 | - $qb->andWhere($qb->expr()->notIn('id', $qb->createNamedParameter($taskIdsToIgnore, IQueryBuilder::PARAM_INT_ARRAY))); |
|
| 263 | - } |
|
| 261 | + if (!empty($taskIdsToIgnore)) { |
|
| 262 | + $qb->andWhere($qb->expr()->notIn('id', $qb->createNamedParameter($taskIdsToIgnore, IQueryBuilder::PARAM_INT_ARRAY))); |
|
| 263 | + } |
|
| 264 | 264 | |
| 265 | - return $this->findEntities($qb); |
|
| 266 | - } |
|
| 265 | + return $this->findEntities($qb); |
|
| 266 | + } |
|
| 267 | 267 | |
| 268 | - /** |
|
| 269 | - * @throws Exception |
|
| 270 | - */ |
|
| 271 | - public function hasRunningTasksForTaskType(string $getTaskTypeId): bool { |
|
| 272 | - $qb = $this->db->getQueryBuilder(); |
|
| 273 | - $qb->select('id') |
|
| 274 | - ->from($this->tableName); |
|
| 275 | - $qb->where($qb->expr()->eq('type', $qb->createNamedParameter($getTaskTypeId))); |
|
| 276 | - $qb->andWhere($qb->expr()->eq('status', $qb->createNamedParameter(\OCP\TaskProcessing\Task::STATUS_RUNNING, IQueryBuilder::PARAM_INT))); |
|
| 277 | - $qb->setMaxResults(1); |
|
| 278 | - $result = $qb->executeQuery(); |
|
| 279 | - $hasRunningTasks = $result->fetch() !== false; |
|
| 280 | - $result->closeCursor(); |
|
| 281 | - return $hasRunningTasks; |
|
| 282 | - } |
|
| 268 | + /** |
|
| 269 | + * @throws Exception |
|
| 270 | + */ |
|
| 271 | + public function hasRunningTasksForTaskType(string $getTaskTypeId): bool { |
|
| 272 | + $qb = $this->db->getQueryBuilder(); |
|
| 273 | + $qb->select('id') |
|
| 274 | + ->from($this->tableName); |
|
| 275 | + $qb->where($qb->expr()->eq('type', $qb->createNamedParameter($getTaskTypeId))); |
|
| 276 | + $qb->andWhere($qb->expr()->eq('status', $qb->createNamedParameter(\OCP\TaskProcessing\Task::STATUS_RUNNING, IQueryBuilder::PARAM_INT))); |
|
| 277 | + $qb->setMaxResults(1); |
|
| 278 | + $result = $qb->executeQuery(); |
|
| 279 | + $hasRunningTasks = $result->fetch() !== false; |
|
| 280 | + $result->closeCursor(); |
|
| 281 | + return $hasRunningTasks; |
|
| 282 | + } |
|
| 283 | 283 | } |
@@ -72,1628 +72,1628 @@ |
||
| 72 | 72 | |
| 73 | 73 | class Manager implements IManager { |
| 74 | 74 | |
| 75 | - public const LEGACY_PREFIX_TEXTPROCESSING = 'legacy:TextProcessing:'; |
|
| 76 | - public const LEGACY_PREFIX_TEXTTOIMAGE = 'legacy:TextToImage:'; |
|
| 77 | - public const LEGACY_PREFIX_SPEECHTOTEXT = 'legacy:SpeechToText:'; |
|
| 78 | - |
|
| 79 | - public const LAZY_CONFIG_KEYS = [ |
|
| 80 | - 'ai.taskprocessing_type_preferences', |
|
| 81 | - 'ai.taskprocessing_provider_preferences', |
|
| 82 | - ]; |
|
| 83 | - |
|
| 84 | - public const MAX_TASK_AGE_SECONDS = 60 * 60 * 24 * 31 * 6; // 6 months |
|
| 85 | - |
|
| 86 | - private const TASK_TYPES_CACHE_KEY = 'available_task_types_v3'; |
|
| 87 | - private const TASK_TYPE_IDS_CACHE_KEY = 'available_task_type_ids'; |
|
| 88 | - |
|
| 89 | - /** @var list<IProvider>|null */ |
|
| 90 | - private ?array $providers = null; |
|
| 91 | - |
|
| 92 | - /** |
|
| 93 | - * @var array<array-key,array{name: string, description: string, inputShape: ShapeDescriptor[], inputShapeEnumValues: ShapeEnumValue[][], inputShapeDefaults: array<array-key, numeric|string>, isInternal: bool, optionalInputShape: ShapeDescriptor[], optionalInputShapeEnumValues: ShapeEnumValue[][], optionalInputShapeDefaults: array<array-key, numeric|string>, outputShape: ShapeDescriptor[], outputShapeEnumValues: ShapeEnumValue[][], optionalOutputShape: ShapeDescriptor[], optionalOutputShapeEnumValues: ShapeEnumValue[][]}> |
|
| 94 | - */ |
|
| 95 | - private ?array $availableTaskTypes = null; |
|
| 96 | - |
|
| 97 | - /** @var list<string>|null */ |
|
| 98 | - private ?array $availableTaskTypeIds = null; |
|
| 99 | - |
|
| 100 | - private IAppData $appData; |
|
| 101 | - private ?array $preferences = null; |
|
| 102 | - private ?array $providersById = null; |
|
| 103 | - |
|
| 104 | - /** @var ITaskType[]|null */ |
|
| 105 | - private ?array $taskTypes = null; |
|
| 106 | - private ICache $distributedCache; |
|
| 107 | - |
|
| 108 | - private ?GetTaskProcessingProvidersEvent $eventResult = null; |
|
| 109 | - |
|
| 110 | - public function __construct( |
|
| 111 | - private IAppConfig $appConfig, |
|
| 112 | - private Coordinator $coordinator, |
|
| 113 | - private IServerContainer $serverContainer, |
|
| 114 | - private LoggerInterface $logger, |
|
| 115 | - private TaskMapper $taskMapper, |
|
| 116 | - private IJobList $jobList, |
|
| 117 | - private IEventDispatcher $dispatcher, |
|
| 118 | - IAppDataFactory $appDataFactory, |
|
| 119 | - private IRootFolder $rootFolder, |
|
| 120 | - private \OCP\TextToImage\IManager $textToImageManager, |
|
| 121 | - private IUserMountCache $userMountCache, |
|
| 122 | - private IClientService $clientService, |
|
| 123 | - private IAppManager $appManager, |
|
| 124 | - private IUserManager $userManager, |
|
| 125 | - private IUserSession $userSession, |
|
| 126 | - ICacheFactory $cacheFactory, |
|
| 127 | - private IFactory $l10nFactory, |
|
| 128 | - ) { |
|
| 129 | - $this->appData = $appDataFactory->get('core'); |
|
| 130 | - $this->distributedCache = $cacheFactory->createDistributed('task_processing::'); |
|
| 131 | - } |
|
| 132 | - |
|
| 133 | - |
|
| 134 | - /** |
|
| 135 | - * This is almost a copy of textProcessingManager->getProviders |
|
| 136 | - * to avoid a dependency cycle between TextProcessingManager and TaskProcessingManager |
|
| 137 | - */ |
|
| 138 | - private function _getRawTextProcessingProviders(): array { |
|
| 139 | - $context = $this->coordinator->getRegistrationContext(); |
|
| 140 | - if ($context === null) { |
|
| 141 | - return []; |
|
| 142 | - } |
|
| 143 | - |
|
| 144 | - $providers = []; |
|
| 145 | - |
|
| 146 | - foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) { |
|
| 147 | - $class = $providerServiceRegistration->getService(); |
|
| 148 | - try { |
|
| 149 | - $providers[$class] = $this->serverContainer->get($class); |
|
| 150 | - } catch (\Throwable $e) { |
|
| 151 | - $this->logger->error('Failed to load Text processing provider ' . $class, [ |
|
| 152 | - 'exception' => $e, |
|
| 153 | - ]); |
|
| 154 | - } |
|
| 155 | - } |
|
| 156 | - |
|
| 157 | - return $providers; |
|
| 158 | - } |
|
| 159 | - |
|
| 160 | - private function _getTextProcessingProviders(): array { |
|
| 161 | - $oldProviders = $this->_getRawTextProcessingProviders(); |
|
| 162 | - $newProviders = []; |
|
| 163 | - foreach ($oldProviders as $oldProvider) { |
|
| 164 | - $provider = new class($oldProvider) implements IProvider, ISynchronousProvider { |
|
| 165 | - private \OCP\TextProcessing\IProvider $provider; |
|
| 166 | - |
|
| 167 | - public function __construct(\OCP\TextProcessing\IProvider $provider) { |
|
| 168 | - $this->provider = $provider; |
|
| 169 | - } |
|
| 170 | - |
|
| 171 | - public function getId(): string { |
|
| 172 | - if ($this->provider instanceof \OCP\TextProcessing\IProviderWithId) { |
|
| 173 | - return $this->provider->getId(); |
|
| 174 | - } |
|
| 175 | - return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider::class; |
|
| 176 | - } |
|
| 177 | - |
|
| 178 | - public function getName(): string { |
|
| 179 | - return $this->provider->getName(); |
|
| 180 | - } |
|
| 181 | - |
|
| 182 | - public function getTaskTypeId(): string { |
|
| 183 | - return match ($this->provider->getTaskType()) { |
|
| 184 | - \OCP\TextProcessing\FreePromptTaskType::class => TextToText::ID, |
|
| 185 | - \OCP\TextProcessing\HeadlineTaskType::class => TextToTextHeadline::ID, |
|
| 186 | - \OCP\TextProcessing\TopicsTaskType::class => TextToTextTopics::ID, |
|
| 187 | - \OCP\TextProcessing\SummaryTaskType::class => TextToTextSummary::ID, |
|
| 188 | - default => Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider->getTaskType(), |
|
| 189 | - }; |
|
| 190 | - } |
|
| 191 | - |
|
| 192 | - public function getExpectedRuntime(): int { |
|
| 193 | - if ($this->provider instanceof \OCP\TextProcessing\IProviderWithExpectedRuntime) { |
|
| 194 | - return $this->provider->getExpectedRuntime(); |
|
| 195 | - } |
|
| 196 | - return 60; |
|
| 197 | - } |
|
| 198 | - |
|
| 199 | - public function getOptionalInputShape(): array { |
|
| 200 | - return []; |
|
| 201 | - } |
|
| 202 | - |
|
| 203 | - public function getOptionalOutputShape(): array { |
|
| 204 | - return []; |
|
| 205 | - } |
|
| 206 | - |
|
| 207 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 208 | - if ($this->provider instanceof \OCP\TextProcessing\IProviderWithUserId) { |
|
| 209 | - $this->provider->setUserId($userId); |
|
| 210 | - } |
|
| 211 | - try { |
|
| 212 | - return ['output' => $this->provider->process($input['input'])]; |
|
| 213 | - } catch (\RuntimeException $e) { |
|
| 214 | - throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 215 | - } |
|
| 216 | - } |
|
| 217 | - |
|
| 218 | - public function getInputShapeEnumValues(): array { |
|
| 219 | - return []; |
|
| 220 | - } |
|
| 221 | - |
|
| 222 | - public function getInputShapeDefaults(): array { |
|
| 223 | - return []; |
|
| 224 | - } |
|
| 225 | - |
|
| 226 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 227 | - return []; |
|
| 228 | - } |
|
| 229 | - |
|
| 230 | - public function getOptionalInputShapeDefaults(): array { |
|
| 231 | - return []; |
|
| 232 | - } |
|
| 233 | - |
|
| 234 | - public function getOutputShapeEnumValues(): array { |
|
| 235 | - return []; |
|
| 236 | - } |
|
| 237 | - |
|
| 238 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 239 | - return []; |
|
| 240 | - } |
|
| 241 | - }; |
|
| 242 | - $newProviders[$provider->getId()] = $provider; |
|
| 243 | - } |
|
| 244 | - |
|
| 245 | - return $newProviders; |
|
| 246 | - } |
|
| 247 | - |
|
| 248 | - /** |
|
| 249 | - * @return ITaskType[] |
|
| 250 | - */ |
|
| 251 | - private function _getTextProcessingTaskTypes(): array { |
|
| 252 | - $oldProviders = $this->_getRawTextProcessingProviders(); |
|
| 253 | - $newTaskTypes = []; |
|
| 254 | - foreach ($oldProviders as $oldProvider) { |
|
| 255 | - // These are already implemented in the TaskProcessing realm |
|
| 256 | - if (in_array($oldProvider->getTaskType(), [ |
|
| 257 | - \OCP\TextProcessing\FreePromptTaskType::class, |
|
| 258 | - \OCP\TextProcessing\HeadlineTaskType::class, |
|
| 259 | - \OCP\TextProcessing\TopicsTaskType::class, |
|
| 260 | - \OCP\TextProcessing\SummaryTaskType::class |
|
| 261 | - ], true)) { |
|
| 262 | - continue; |
|
| 263 | - } |
|
| 264 | - $taskType = new class($oldProvider->getTaskType()) implements ITaskType { |
|
| 265 | - private string $oldTaskTypeClass; |
|
| 266 | - private \OCP\TextProcessing\ITaskType $oldTaskType; |
|
| 267 | - private IL10N $l; |
|
| 268 | - |
|
| 269 | - public function __construct(string $oldTaskTypeClass) { |
|
| 270 | - $this->oldTaskTypeClass = $oldTaskTypeClass; |
|
| 271 | - $this->oldTaskType = \OCP\Server::get($oldTaskTypeClass); |
|
| 272 | - $this->l = \OCP\Server::get(IFactory::class)->get('core'); |
|
| 273 | - } |
|
| 274 | - |
|
| 275 | - public function getId(): string { |
|
| 276 | - return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->oldTaskTypeClass; |
|
| 277 | - } |
|
| 278 | - |
|
| 279 | - public function getName(): string { |
|
| 280 | - return $this->oldTaskType->getName(); |
|
| 281 | - } |
|
| 282 | - |
|
| 283 | - public function getDescription(): string { |
|
| 284 | - return $this->oldTaskType->getDescription(); |
|
| 285 | - } |
|
| 286 | - |
|
| 287 | - public function getInputShape(): array { |
|
| 288 | - return ['input' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)]; |
|
| 289 | - } |
|
| 290 | - |
|
| 291 | - public function getOutputShape(): array { |
|
| 292 | - return ['output' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)]; |
|
| 293 | - } |
|
| 294 | - }; |
|
| 295 | - $newTaskTypes[$taskType->getId()] = $taskType; |
|
| 296 | - } |
|
| 297 | - |
|
| 298 | - return $newTaskTypes; |
|
| 299 | - } |
|
| 300 | - |
|
| 301 | - /** |
|
| 302 | - * @return IProvider[] |
|
| 303 | - */ |
|
| 304 | - private function _getTextToImageProviders(): array { |
|
| 305 | - $oldProviders = $this->textToImageManager->getProviders(); |
|
| 306 | - $newProviders = []; |
|
| 307 | - foreach ($oldProviders as $oldProvider) { |
|
| 308 | - $newProvider = new class($oldProvider, $this->appData) implements IProvider, ISynchronousProvider { |
|
| 309 | - private \OCP\TextToImage\IProvider $provider; |
|
| 310 | - private IAppData $appData; |
|
| 311 | - |
|
| 312 | - public function __construct(\OCP\TextToImage\IProvider $provider, IAppData $appData) { |
|
| 313 | - $this->provider = $provider; |
|
| 314 | - $this->appData = $appData; |
|
| 315 | - } |
|
| 316 | - |
|
| 317 | - public function getId(): string { |
|
| 318 | - return Manager::LEGACY_PREFIX_TEXTTOIMAGE . $this->provider->getId(); |
|
| 319 | - } |
|
| 320 | - |
|
| 321 | - public function getName(): string { |
|
| 322 | - return $this->provider->getName(); |
|
| 323 | - } |
|
| 324 | - |
|
| 325 | - public function getTaskTypeId(): string { |
|
| 326 | - return TextToImage::ID; |
|
| 327 | - } |
|
| 328 | - |
|
| 329 | - public function getExpectedRuntime(): int { |
|
| 330 | - return $this->provider->getExpectedRuntime(); |
|
| 331 | - } |
|
| 332 | - |
|
| 333 | - public function getOptionalInputShape(): array { |
|
| 334 | - return []; |
|
| 335 | - } |
|
| 336 | - |
|
| 337 | - public function getOptionalOutputShape(): array { |
|
| 338 | - return []; |
|
| 339 | - } |
|
| 340 | - |
|
| 341 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 342 | - try { |
|
| 343 | - $folder = $this->appData->getFolder('text2image'); |
|
| 344 | - } catch (\OCP\Files\NotFoundException) { |
|
| 345 | - $folder = $this->appData->newFolder('text2image'); |
|
| 346 | - } |
|
| 347 | - $resources = []; |
|
| 348 | - $files = []; |
|
| 349 | - for ($i = 0; $i < $input['numberOfImages']; $i++) { |
|
| 350 | - $file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i); |
|
| 351 | - $files[] = $file; |
|
| 352 | - $resource = $file->write(); |
|
| 353 | - if ($resource !== false && $resource !== true && is_resource($resource)) { |
|
| 354 | - $resources[] = $resource; |
|
| 355 | - } else { |
|
| 356 | - throw new ProcessingException('Text2Image generation using provider "' . $this->getName() . '" failed: Couldn\'t open file to write.'); |
|
| 357 | - } |
|
| 358 | - } |
|
| 359 | - if ($this->provider instanceof \OCP\TextToImage\IProviderWithUserId) { |
|
| 360 | - $this->provider->setUserId($userId); |
|
| 361 | - } |
|
| 362 | - try { |
|
| 363 | - $this->provider->generate($input['input'], $resources); |
|
| 364 | - } catch (\RuntimeException $e) { |
|
| 365 | - throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 366 | - } |
|
| 367 | - for ($i = 0; $i < $input['numberOfImages']; $i++) { |
|
| 368 | - if (is_resource($resources[$i])) { |
|
| 369 | - // If $resource hasn't been closed yet, we'll do that here |
|
| 370 | - fclose($resources[$i]); |
|
| 371 | - } |
|
| 372 | - } |
|
| 373 | - return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)]; |
|
| 374 | - } |
|
| 375 | - |
|
| 376 | - public function getInputShapeEnumValues(): array { |
|
| 377 | - return []; |
|
| 378 | - } |
|
| 379 | - |
|
| 380 | - public function getInputShapeDefaults(): array { |
|
| 381 | - return []; |
|
| 382 | - } |
|
| 383 | - |
|
| 384 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 385 | - return []; |
|
| 386 | - } |
|
| 387 | - |
|
| 388 | - public function getOptionalInputShapeDefaults(): array { |
|
| 389 | - return []; |
|
| 390 | - } |
|
| 391 | - |
|
| 392 | - public function getOutputShapeEnumValues(): array { |
|
| 393 | - return []; |
|
| 394 | - } |
|
| 395 | - |
|
| 396 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 397 | - return []; |
|
| 398 | - } |
|
| 399 | - }; |
|
| 400 | - $newProviders[$newProvider->getId()] = $newProvider; |
|
| 401 | - } |
|
| 402 | - |
|
| 403 | - return $newProviders; |
|
| 404 | - } |
|
| 405 | - |
|
| 406 | - /** |
|
| 407 | - * This is almost a copy of SpeechToTextManager->getProviders |
|
| 408 | - * to avoid a dependency cycle between SpeechToTextManager and TaskProcessingManager |
|
| 409 | - */ |
|
| 410 | - private function _getRawSpeechToTextProviders(): array { |
|
| 411 | - $context = $this->coordinator->getRegistrationContext(); |
|
| 412 | - if ($context === null) { |
|
| 413 | - return []; |
|
| 414 | - } |
|
| 415 | - $providers = []; |
|
| 416 | - foreach ($context->getSpeechToTextProviders() as $providerServiceRegistration) { |
|
| 417 | - $class = $providerServiceRegistration->getService(); |
|
| 418 | - try { |
|
| 419 | - $providers[$class] = $this->serverContainer->get($class); |
|
| 420 | - } catch (NotFoundExceptionInterface|ContainerExceptionInterface|\Throwable $e) { |
|
| 421 | - $this->logger->error('Failed to load SpeechToText provider ' . $class, [ |
|
| 422 | - 'exception' => $e, |
|
| 423 | - ]); |
|
| 424 | - } |
|
| 425 | - } |
|
| 426 | - |
|
| 427 | - return $providers; |
|
| 428 | - } |
|
| 429 | - |
|
| 430 | - /** |
|
| 431 | - * @return IProvider[] |
|
| 432 | - */ |
|
| 433 | - private function _getSpeechToTextProviders(): array { |
|
| 434 | - $oldProviders = $this->_getRawSpeechToTextProviders(); |
|
| 435 | - $newProviders = []; |
|
| 436 | - foreach ($oldProviders as $oldProvider) { |
|
| 437 | - $newProvider = new class($oldProvider, $this->rootFolder, $this->appData) implements IProvider, ISynchronousProvider { |
|
| 438 | - private ISpeechToTextProvider $provider; |
|
| 439 | - private IAppData $appData; |
|
| 440 | - |
|
| 441 | - private IRootFolder $rootFolder; |
|
| 442 | - |
|
| 443 | - public function __construct(ISpeechToTextProvider $provider, IRootFolder $rootFolder, IAppData $appData) { |
|
| 444 | - $this->provider = $provider; |
|
| 445 | - $this->rootFolder = $rootFolder; |
|
| 446 | - $this->appData = $appData; |
|
| 447 | - } |
|
| 448 | - |
|
| 449 | - public function getId(): string { |
|
| 450 | - if ($this->provider instanceof ISpeechToTextProviderWithId) { |
|
| 451 | - return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider->getId(); |
|
| 452 | - } |
|
| 453 | - return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider::class; |
|
| 454 | - } |
|
| 455 | - |
|
| 456 | - public function getName(): string { |
|
| 457 | - return $this->provider->getName(); |
|
| 458 | - } |
|
| 459 | - |
|
| 460 | - public function getTaskTypeId(): string { |
|
| 461 | - return AudioToText::ID; |
|
| 462 | - } |
|
| 463 | - |
|
| 464 | - public function getExpectedRuntime(): int { |
|
| 465 | - return 60; |
|
| 466 | - } |
|
| 467 | - |
|
| 468 | - public function getOptionalInputShape(): array { |
|
| 469 | - return []; |
|
| 470 | - } |
|
| 471 | - |
|
| 472 | - public function getOptionalOutputShape(): array { |
|
| 473 | - return []; |
|
| 474 | - } |
|
| 475 | - |
|
| 476 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 477 | - if ($this->provider instanceof \OCP\SpeechToText\ISpeechToTextProviderWithUserId) { |
|
| 478 | - $this->provider->setUserId($userId); |
|
| 479 | - } |
|
| 480 | - try { |
|
| 481 | - $result = $this->provider->transcribeFile($input['input']); |
|
| 482 | - } catch (\RuntimeException $e) { |
|
| 483 | - throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 484 | - } |
|
| 485 | - return ['output' => $result]; |
|
| 486 | - } |
|
| 487 | - |
|
| 488 | - public function getInputShapeEnumValues(): array { |
|
| 489 | - return []; |
|
| 490 | - } |
|
| 491 | - |
|
| 492 | - public function getInputShapeDefaults(): array { |
|
| 493 | - return []; |
|
| 494 | - } |
|
| 495 | - |
|
| 496 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 497 | - return []; |
|
| 498 | - } |
|
| 499 | - |
|
| 500 | - public function getOptionalInputShapeDefaults(): array { |
|
| 501 | - return []; |
|
| 502 | - } |
|
| 503 | - |
|
| 504 | - public function getOutputShapeEnumValues(): array { |
|
| 505 | - return []; |
|
| 506 | - } |
|
| 507 | - |
|
| 508 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 509 | - return []; |
|
| 510 | - } |
|
| 511 | - }; |
|
| 512 | - $newProviders[$newProvider->getId()] = $newProvider; |
|
| 513 | - } |
|
| 514 | - |
|
| 515 | - return $newProviders; |
|
| 516 | - } |
|
| 517 | - |
|
| 518 | - /** |
|
| 519 | - * Dispatches the event to collect external providers and task types. |
|
| 520 | - * Caches the result within the request. |
|
| 521 | - */ |
|
| 522 | - private function dispatchGetProvidersEvent(): GetTaskProcessingProvidersEvent { |
|
| 523 | - if ($this->eventResult !== null) { |
|
| 524 | - return $this->eventResult; |
|
| 525 | - } |
|
| 526 | - |
|
| 527 | - $this->eventResult = new GetTaskProcessingProvidersEvent(); |
|
| 528 | - $this->dispatcher->dispatchTyped($this->eventResult); |
|
| 529 | - return $this->eventResult ; |
|
| 530 | - } |
|
| 531 | - |
|
| 532 | - /** |
|
| 533 | - * @return IProvider[] |
|
| 534 | - */ |
|
| 535 | - private function _getProviders(): array { |
|
| 536 | - $context = $this->coordinator->getRegistrationContext(); |
|
| 537 | - |
|
| 538 | - if ($context === null) { |
|
| 539 | - return []; |
|
| 540 | - } |
|
| 541 | - |
|
| 542 | - $providers = []; |
|
| 543 | - |
|
| 544 | - foreach ($context->getTaskProcessingProviders() as $providerServiceRegistration) { |
|
| 545 | - $class = $providerServiceRegistration->getService(); |
|
| 546 | - try { |
|
| 547 | - /** @var IProvider $provider */ |
|
| 548 | - $provider = $this->serverContainer->get($class); |
|
| 549 | - if (isset($providers[$provider->getId()])) { |
|
| 550 | - $this->logger->warning('Task processing provider ' . $class . ' is using ID ' . $provider->getId() . ' which is already used by ' . $providers[$provider->getId()]::class); |
|
| 551 | - } |
|
| 552 | - $providers[$provider->getId()] = $provider; |
|
| 553 | - } catch (\Throwable $e) { |
|
| 554 | - $this->logger->error('Failed to load task processing provider ' . $class, [ |
|
| 555 | - 'exception' => $e, |
|
| 556 | - ]); |
|
| 557 | - } |
|
| 558 | - } |
|
| 559 | - |
|
| 560 | - $event = $this->dispatchGetProvidersEvent(); |
|
| 561 | - $externalProviders = $event->getProviders(); |
|
| 562 | - foreach ($externalProviders as $provider) { |
|
| 563 | - if (!isset($providers[$provider->getId()])) { |
|
| 564 | - $providers[$provider->getId()] = $provider; |
|
| 565 | - } else { |
|
| 566 | - $this->logger->info('Skipping external task processing provider with ID ' . $provider->getId() . ' because a local provider with the same ID already exists.'); |
|
| 567 | - } |
|
| 568 | - } |
|
| 569 | - |
|
| 570 | - $providers += $this->_getTextProcessingProviders() + $this->_getTextToImageProviders() + $this->_getSpeechToTextProviders(); |
|
| 571 | - |
|
| 572 | - return $providers; |
|
| 573 | - } |
|
| 574 | - |
|
| 575 | - /** |
|
| 576 | - * @return ITaskType[] |
|
| 577 | - */ |
|
| 578 | - private function _getTaskTypes(): array { |
|
| 579 | - $context = $this->coordinator->getRegistrationContext(); |
|
| 580 | - |
|
| 581 | - if ($context === null) { |
|
| 582 | - return []; |
|
| 583 | - } |
|
| 584 | - |
|
| 585 | - if ($this->taskTypes !== null) { |
|
| 586 | - return $this->taskTypes; |
|
| 587 | - } |
|
| 588 | - |
|
| 589 | - // Default task types |
|
| 590 | - $taskTypes = [ |
|
| 591 | - \OCP\TaskProcessing\TaskTypes\TextToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToText::class), |
|
| 592 | - \OCP\TaskProcessing\TaskTypes\TextToTextTopics::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextTopics::class), |
|
| 593 | - \OCP\TaskProcessing\TaskTypes\TextToTextHeadline::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextHeadline::class), |
|
| 594 | - \OCP\TaskProcessing\TaskTypes\TextToTextSummary::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextSummary::class), |
|
| 595 | - \OCP\TaskProcessing\TaskTypes\TextToTextFormalization::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextFormalization::class), |
|
| 596 | - \OCP\TaskProcessing\TaskTypes\TextToTextSimplification::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextSimplification::class), |
|
| 597 | - \OCP\TaskProcessing\TaskTypes\TextToTextChat::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChat::class), |
|
| 598 | - \OCP\TaskProcessing\TaskTypes\TextToTextTranslate::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextTranslate::class), |
|
| 599 | - \OCP\TaskProcessing\TaskTypes\TextToTextReformulation::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextReformulation::class), |
|
| 600 | - \OCP\TaskProcessing\TaskTypes\TextToImage::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToImage::class), |
|
| 601 | - \OCP\TaskProcessing\TaskTypes\AudioToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToText::class), |
|
| 602 | - \OCP\TaskProcessing\TaskTypes\ContextWrite::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextWrite::class), |
|
| 603 | - \OCP\TaskProcessing\TaskTypes\GenerateEmoji::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\GenerateEmoji::class), |
|
| 604 | - \OCP\TaskProcessing\TaskTypes\TextToTextChangeTone::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChangeTone::class), |
|
| 605 | - \OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::class), |
|
| 606 | - \OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::class), |
|
| 607 | - \OCP\TaskProcessing\TaskTypes\TextToTextProofread::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextProofread::class), |
|
| 608 | - \OCP\TaskProcessing\TaskTypes\TextToSpeech::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToSpeech::class), |
|
| 609 | - \OCP\TaskProcessing\TaskTypes\AudioToAudioChat::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToAudioChat::class), |
|
| 610 | - \OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::class), |
|
| 611 | - \OCP\TaskProcessing\TaskTypes\AnalyzeImages::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AnalyzeImages::class), |
|
| 612 | - ]; |
|
| 613 | - |
|
| 614 | - foreach ($context->getTaskProcessingTaskTypes() as $providerServiceRegistration) { |
|
| 615 | - $class = $providerServiceRegistration->getService(); |
|
| 616 | - try { |
|
| 617 | - /** @var ITaskType $provider */ |
|
| 618 | - $taskType = $this->serverContainer->get($class); |
|
| 619 | - if (isset($taskTypes[$taskType->getId()])) { |
|
| 620 | - $this->logger->warning('Task processing task type ' . $class . ' is using ID ' . $taskType->getId() . ' which is already used by ' . $taskTypes[$taskType->getId()]::class); |
|
| 621 | - } |
|
| 622 | - $taskTypes[$taskType->getId()] = $taskType; |
|
| 623 | - } catch (\Throwable $e) { |
|
| 624 | - $this->logger->error('Failed to load task processing task type ' . $class, [ |
|
| 625 | - 'exception' => $e, |
|
| 626 | - ]); |
|
| 627 | - } |
|
| 628 | - } |
|
| 629 | - |
|
| 630 | - $event = $this->dispatchGetProvidersEvent(); |
|
| 631 | - $externalTaskTypes = $event->getTaskTypes(); |
|
| 632 | - foreach ($externalTaskTypes as $taskType) { |
|
| 633 | - if (isset($taskTypes[$taskType->getId()])) { |
|
| 634 | - $this->logger->warning('External task processing task type is using ID ' . $taskType->getId() . ' which is already used by a locally registered task type (' . get_class($taskTypes[$taskType->getId()]) . ')'); |
|
| 635 | - } |
|
| 636 | - $taskTypes[$taskType->getId()] = $taskType; |
|
| 637 | - } |
|
| 638 | - |
|
| 639 | - $taskTypes += $this->_getTextProcessingTaskTypes(); |
|
| 640 | - |
|
| 641 | - $this->taskTypes = $taskTypes; |
|
| 642 | - return $this->taskTypes; |
|
| 643 | - } |
|
| 644 | - |
|
| 645 | - /** |
|
| 646 | - * @return array |
|
| 647 | - */ |
|
| 648 | - private function _getTaskTypeSettings(): array { |
|
| 649 | - try { |
|
| 650 | - $json = $this->appConfig->getValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true); |
|
| 651 | - if ($json === '') { |
|
| 652 | - return []; |
|
| 653 | - } |
|
| 654 | - return json_decode($json, true, flags: JSON_THROW_ON_ERROR); |
|
| 655 | - } catch (\JsonException $e) { |
|
| 656 | - $this->logger->error('Failed to get settings. JSON Error in ai.taskprocessing_type_preferences', ['exception' => $e]); |
|
| 657 | - $taskTypeSettings = []; |
|
| 658 | - $taskTypes = $this->_getTaskTypes(); |
|
| 659 | - foreach ($taskTypes as $taskType) { |
|
| 660 | - $taskTypeSettings[$taskType->getId()] = false; |
|
| 661 | - }; |
|
| 662 | - |
|
| 663 | - return $taskTypeSettings; |
|
| 664 | - } |
|
| 665 | - |
|
| 666 | - } |
|
| 667 | - |
|
| 668 | - /** |
|
| 669 | - * @param ShapeDescriptor[] $spec |
|
| 670 | - * @param array<array-key, string|numeric> $defaults |
|
| 671 | - * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 672 | - * @param array $io |
|
| 673 | - * @param bool $optional |
|
| 674 | - * @return void |
|
| 675 | - * @throws ValidationException |
|
| 676 | - */ |
|
| 677 | - private static function validateInput(array $spec, array $defaults, array $enumValues, array $io, bool $optional = false): void { |
|
| 678 | - foreach ($spec as $key => $descriptor) { |
|
| 679 | - $type = $descriptor->getShapeType(); |
|
| 680 | - if (!isset($io[$key])) { |
|
| 681 | - if ($optional) { |
|
| 682 | - continue; |
|
| 683 | - } |
|
| 684 | - if (isset($defaults[$key])) { |
|
| 685 | - if (EShapeType::getScalarType($type) !== $type) { |
|
| 686 | - throw new ValidationException('Provider tried to set a default value for a non-scalar slot'); |
|
| 687 | - } |
|
| 688 | - if (EShapeType::isFileType($type)) { |
|
| 689 | - throw new ValidationException('Provider tried to set a default value for a slot that is not text or number'); |
|
| 690 | - } |
|
| 691 | - $type->validateInput($defaults[$key]); |
|
| 692 | - continue; |
|
| 693 | - } |
|
| 694 | - throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 695 | - } |
|
| 696 | - try { |
|
| 697 | - $type->validateInput($io[$key]); |
|
| 698 | - if ($type === EShapeType::Enum) { |
|
| 699 | - if (!isset($enumValues[$key])) { |
|
| 700 | - throw new ValidationException('Provider did not provide enum values for an enum slot: "' . $key . '"'); |
|
| 701 | - } |
|
| 702 | - $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 703 | - } |
|
| 704 | - } catch (ValidationException $e) { |
|
| 705 | - throw new ValidationException('Failed to validate input key "' . $key . '": ' . $e->getMessage()); |
|
| 706 | - } |
|
| 707 | - } |
|
| 708 | - } |
|
| 709 | - |
|
| 710 | - /** |
|
| 711 | - * Takes task input data and replaces fileIds with File objects |
|
| 712 | - * |
|
| 713 | - * @param array<array-key, list<numeric|string>|numeric|string> $input |
|
| 714 | - * @param array<array-key, numeric|string> ...$defaultSpecs the specs |
|
| 715 | - * @return array<array-key, list<numeric|string>|numeric|string> |
|
| 716 | - */ |
|
| 717 | - public function fillInputDefaults(array $input, ...$defaultSpecs): array { |
|
| 718 | - $spec = array_reduce($defaultSpecs, fn ($carry, $spec) => array_merge($carry, $spec), []); |
|
| 719 | - return array_merge($spec, $input); |
|
| 720 | - } |
|
| 721 | - |
|
| 722 | - /** |
|
| 723 | - * @param ShapeDescriptor[] $spec |
|
| 724 | - * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 725 | - * @param array $io |
|
| 726 | - * @param bool $optional |
|
| 727 | - * @return void |
|
| 728 | - * @throws ValidationException |
|
| 729 | - */ |
|
| 730 | - private static function validateOutputWithFileIds(array $spec, array $enumValues, array $io, bool $optional = false): void { |
|
| 731 | - foreach ($spec as $key => $descriptor) { |
|
| 732 | - $type = $descriptor->getShapeType(); |
|
| 733 | - if (!isset($io[$key])) { |
|
| 734 | - if ($optional) { |
|
| 735 | - continue; |
|
| 736 | - } |
|
| 737 | - throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 738 | - } |
|
| 739 | - try { |
|
| 740 | - $type->validateOutputWithFileIds($io[$key]); |
|
| 741 | - if (isset($enumValues[$key])) { |
|
| 742 | - $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 743 | - } |
|
| 744 | - } catch (ValidationException $e) { |
|
| 745 | - throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage()); |
|
| 746 | - } |
|
| 747 | - } |
|
| 748 | - } |
|
| 749 | - |
|
| 750 | - /** |
|
| 751 | - * @param ShapeDescriptor[] $spec |
|
| 752 | - * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 753 | - * @param array $io |
|
| 754 | - * @param bool $optional |
|
| 755 | - * @return void |
|
| 756 | - * @throws ValidationException |
|
| 757 | - */ |
|
| 758 | - private static function validateOutputWithFileData(array $spec, array $enumValues, array $io, bool $optional = false): void { |
|
| 759 | - foreach ($spec as $key => $descriptor) { |
|
| 760 | - $type = $descriptor->getShapeType(); |
|
| 761 | - if (!isset($io[$key])) { |
|
| 762 | - if ($optional) { |
|
| 763 | - continue; |
|
| 764 | - } |
|
| 765 | - throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 766 | - } |
|
| 767 | - try { |
|
| 768 | - $type->validateOutputWithFileData($io[$key]); |
|
| 769 | - if (isset($enumValues[$key])) { |
|
| 770 | - $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 771 | - } |
|
| 772 | - } catch (ValidationException $e) { |
|
| 773 | - throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage()); |
|
| 774 | - } |
|
| 775 | - } |
|
| 776 | - } |
|
| 777 | - |
|
| 778 | - /** |
|
| 779 | - * @param array<array-key, T> $array The array to filter |
|
| 780 | - * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 781 | - * @return array<array-key, T> |
|
| 782 | - * @psalm-template T |
|
| 783 | - */ |
|
| 784 | - private function removeSuperfluousArrayKeys(array $array, ...$specs): array { |
|
| 785 | - $keys = array_unique(array_reduce($specs, fn ($carry, $spec) => array_merge($carry, array_keys($spec)), [])); |
|
| 786 | - $keys = array_filter($keys, fn ($key) => array_key_exists($key, $array)); |
|
| 787 | - $values = array_map(fn (string $key) => $array[$key], $keys); |
|
| 788 | - return array_combine($keys, $values); |
|
| 789 | - } |
|
| 790 | - |
|
| 791 | - public function hasProviders(): bool { |
|
| 792 | - return count($this->getProviders()) !== 0; |
|
| 793 | - } |
|
| 794 | - |
|
| 795 | - public function getProviders(): array { |
|
| 796 | - if ($this->providers === null) { |
|
| 797 | - $this->providers = $this->_getProviders(); |
|
| 798 | - } |
|
| 799 | - |
|
| 800 | - return $this->providers; |
|
| 801 | - } |
|
| 802 | - |
|
| 803 | - public function getPreferredProvider(string $taskTypeId) { |
|
| 804 | - try { |
|
| 805 | - if ($this->preferences === null) { |
|
| 806 | - $this->preferences = $this->distributedCache->get('ai.taskprocessing_provider_preferences'); |
|
| 807 | - if ($this->preferences === null) { |
|
| 808 | - $this->preferences = json_decode( |
|
| 809 | - $this->appConfig->getValueString('core', 'ai.taskprocessing_provider_preferences', 'null', lazy: true), |
|
| 810 | - associative: true, |
|
| 811 | - flags: JSON_THROW_ON_ERROR, |
|
| 812 | - ); |
|
| 813 | - $this->distributedCache->set('ai.taskprocessing_provider_preferences', $this->preferences, 60 * 3); |
|
| 814 | - } |
|
| 815 | - } |
|
| 816 | - |
|
| 817 | - $providers = $this->getProviders(); |
|
| 818 | - if (isset($this->preferences[$taskTypeId])) { |
|
| 819 | - $providersById = $this->providersById ?? array_reduce($providers, static function (array $carry, IProvider $provider) { |
|
| 820 | - $carry[$provider->getId()] = $provider; |
|
| 821 | - return $carry; |
|
| 822 | - }, []); |
|
| 823 | - $this->providersById = $providersById; |
|
| 824 | - if (isset($providersById[$this->preferences[$taskTypeId]])) { |
|
| 825 | - return $providersById[$this->preferences[$taskTypeId]]; |
|
| 826 | - } |
|
| 827 | - } |
|
| 828 | - // By default, use the first available provider |
|
| 829 | - foreach ($providers as $provider) { |
|
| 830 | - if ($provider->getTaskTypeId() === $taskTypeId) { |
|
| 831 | - return $provider; |
|
| 832 | - } |
|
| 833 | - } |
|
| 834 | - } catch (\JsonException $e) { |
|
| 835 | - $this->logger->warning('Failed to parse provider preferences while getting preferred provider for task type ' . $taskTypeId, ['exception' => $e]); |
|
| 836 | - } |
|
| 837 | - throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found'); |
|
| 838 | - } |
|
| 839 | - |
|
| 840 | - public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userId = null): array { |
|
| 841 | - // We cache by language, because some task type fields are translated |
|
| 842 | - $cacheKey = self::TASK_TYPES_CACHE_KEY . ':' . $this->l10nFactory->findLanguage(); |
|
| 843 | - |
|
| 844 | - // userId will be obtained from the session if left to null |
|
| 845 | - if (!$this->checkGuestAccess($userId)) { |
|
| 846 | - return []; |
|
| 847 | - } |
|
| 848 | - if ($this->availableTaskTypes === null) { |
|
| 849 | - $cachedValue = $this->distributedCache->get($cacheKey); |
|
| 850 | - if ($cachedValue !== null) { |
|
| 851 | - $this->availableTaskTypes = unserialize($cachedValue); |
|
| 852 | - } |
|
| 853 | - } |
|
| 854 | - // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. |
|
| 855 | - if ($this->availableTaskTypes === null || $showDisabled) { |
|
| 856 | - $taskTypes = $this->_getTaskTypes(); |
|
| 857 | - $taskTypeSettings = $this->_getTaskTypeSettings(); |
|
| 858 | - |
|
| 859 | - $availableTaskTypes = []; |
|
| 860 | - foreach ($taskTypes as $taskType) { |
|
| 861 | - if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) { |
|
| 862 | - continue; |
|
| 863 | - } |
|
| 864 | - try { |
|
| 865 | - $provider = $this->getPreferredProvider($taskType->getId()); |
|
| 866 | - } catch (\OCP\TaskProcessing\Exception\Exception $e) { |
|
| 867 | - continue; |
|
| 868 | - } |
|
| 869 | - try { |
|
| 870 | - $availableTaskTypes[$provider->getTaskTypeId()] = [ |
|
| 871 | - 'name' => $taskType->getName(), |
|
| 872 | - 'description' => $taskType->getDescription(), |
|
| 873 | - 'optionalInputShape' => $provider->getOptionalInputShape(), |
|
| 874 | - 'inputShapeEnumValues' => $provider->getInputShapeEnumValues(), |
|
| 875 | - 'inputShapeDefaults' => $provider->getInputShapeDefaults(), |
|
| 876 | - 'inputShape' => $taskType->getInputShape(), |
|
| 877 | - 'optionalInputShapeEnumValues' => $provider->getOptionalInputShapeEnumValues(), |
|
| 878 | - 'optionalInputShapeDefaults' => $provider->getOptionalInputShapeDefaults(), |
|
| 879 | - 'outputShape' => $taskType->getOutputShape(), |
|
| 880 | - 'outputShapeEnumValues' => $provider->getOutputShapeEnumValues(), |
|
| 881 | - 'optionalOutputShape' => $provider->getOptionalOutputShape(), |
|
| 882 | - 'optionalOutputShapeEnumValues' => $provider->getOptionalOutputShapeEnumValues(), |
|
| 883 | - 'isInternal' => $taskType instanceof IInternalTaskType, |
|
| 884 | - ]; |
|
| 885 | - } catch (\Throwable $e) { |
|
| 886 | - $this->logger->error('Failed to set up TaskProcessing provider ' . $provider::class, ['exception' => $e]); |
|
| 887 | - } |
|
| 888 | - } |
|
| 889 | - |
|
| 890 | - if ($showDisabled) { |
|
| 891 | - // Do not cache showDisabled, ever. |
|
| 892 | - return $availableTaskTypes; |
|
| 893 | - } |
|
| 894 | - |
|
| 895 | - $this->availableTaskTypes = $availableTaskTypes; |
|
| 896 | - $this->distributedCache->set($cacheKey, serialize($this->availableTaskTypes), 60); |
|
| 897 | - } |
|
| 898 | - |
|
| 899 | - |
|
| 900 | - return $this->availableTaskTypes; |
|
| 901 | - } |
|
| 902 | - public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array { |
|
| 903 | - // userId will be obtained from the session if left to null |
|
| 904 | - if (!$this->checkGuestAccess($userId)) { |
|
| 905 | - return []; |
|
| 906 | - } |
|
| 907 | - if ($this->availableTaskTypeIds === null) { |
|
| 908 | - $cachedValue = $this->distributedCache->get(self::TASK_TYPE_IDS_CACHE_KEY); |
|
| 909 | - if ($cachedValue !== null) { |
|
| 910 | - $this->availableTaskTypeIds = $cachedValue; |
|
| 911 | - } |
|
| 912 | - } |
|
| 913 | - // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. |
|
| 914 | - if ($this->availableTaskTypeIds === null || $showDisabled) { |
|
| 915 | - $taskTypes = $this->_getTaskTypes(); |
|
| 916 | - $taskTypeSettings = $this->_getTaskTypeSettings(); |
|
| 917 | - |
|
| 918 | - $availableTaskTypeIds = []; |
|
| 919 | - foreach ($taskTypes as $taskType) { |
|
| 920 | - if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) { |
|
| 921 | - continue; |
|
| 922 | - } |
|
| 923 | - try { |
|
| 924 | - $provider = $this->getPreferredProvider($taskType->getId()); |
|
| 925 | - } catch (\OCP\TaskProcessing\Exception\Exception $e) { |
|
| 926 | - continue; |
|
| 927 | - } |
|
| 928 | - $availableTaskTypeIds[] = $taskType->getId(); |
|
| 929 | - } |
|
| 930 | - |
|
| 931 | - if ($showDisabled) { |
|
| 932 | - // Do not cache showDisabled, ever. |
|
| 933 | - return $availableTaskTypeIds; |
|
| 934 | - } |
|
| 935 | - |
|
| 936 | - $this->availableTaskTypeIds = $availableTaskTypeIds; |
|
| 937 | - $this->distributedCache->set(self::TASK_TYPE_IDS_CACHE_KEY, $this->availableTaskTypeIds, 60); |
|
| 938 | - } |
|
| 939 | - |
|
| 940 | - |
|
| 941 | - return $this->availableTaskTypeIds; |
|
| 942 | - } |
|
| 943 | - |
|
| 944 | - public function canHandleTask(Task $task): bool { |
|
| 945 | - return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]); |
|
| 946 | - } |
|
| 947 | - |
|
| 948 | - private function checkGuestAccess(?string $userId = null): bool { |
|
| 949 | - if ($userId === null && !$this->userSession->isLoggedIn()) { |
|
| 950 | - return true; |
|
| 951 | - } |
|
| 952 | - if ($userId === null) { |
|
| 953 | - $user = $this->userSession->getUser(); |
|
| 954 | - } else { |
|
| 955 | - $user = $this->userManager->get($userId); |
|
| 956 | - } |
|
| 957 | - |
|
| 958 | - $guestsAllowed = $this->appConfig->getValueString('core', 'ai.taskprocessing_guests', 'false'); |
|
| 959 | - if ($guestsAllowed == 'true' || !class_exists(\OCA\Guests\UserBackend::class) || !($user->getBackend() instanceof \OCA\Guests\UserBackend)) { |
|
| 960 | - return true; |
|
| 961 | - } |
|
| 962 | - return false; |
|
| 963 | - } |
|
| 964 | - |
|
| 965 | - public function scheduleTask(Task $task): void { |
|
| 966 | - if (!$this->checkGuestAccess($task->getUserId())) { |
|
| 967 | - throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); |
|
| 968 | - } |
|
| 969 | - if (!$this->canHandleTask($task)) { |
|
| 970 | - throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); |
|
| 971 | - } |
|
| 972 | - $this->prepareTask($task); |
|
| 973 | - $task->setStatus(Task::STATUS_SCHEDULED); |
|
| 974 | - $this->storeTask($task); |
|
| 975 | - // schedule synchronous job if the provider is synchronous |
|
| 976 | - $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 977 | - if ($provider instanceof ISynchronousProvider) { |
|
| 978 | - $this->jobList->add(SynchronousBackgroundJob::class, null); |
|
| 979 | - } |
|
| 980 | - if ($provider instanceof ITriggerableProvider) { |
|
| 981 | - try { |
|
| 982 | - if (!$this->taskMapper->hasRunningTasksForTaskType($task->getTaskTypeId())) { |
|
| 983 | - // If no tasks are currently running for this task type, nudge the provider to ask for tasks |
|
| 984 | - try { |
|
| 985 | - $provider->trigger(); |
|
| 986 | - } catch (\Throwable $e) { |
|
| 987 | - $this->logger->error('Failed to trigger the provider after scheduling a task.', [ |
|
| 988 | - 'exception' => $e, |
|
| 989 | - 'taskId' => $task->getId(), |
|
| 990 | - 'providerId' => $provider->getId(), |
|
| 991 | - ]); |
|
| 992 | - } |
|
| 993 | - } |
|
| 994 | - } catch (Exception $e) { |
|
| 995 | - $this->logger->error('Failed to check DB for running tasks after a task was scheduled for a triggerable provider. Not triggering the provider.', [ |
|
| 996 | - 'exception' => $e, |
|
| 997 | - 'taskId' => $task->getId(), |
|
| 998 | - 'providerId' => $provider->getId() |
|
| 999 | - ]); |
|
| 1000 | - } |
|
| 1001 | - } |
|
| 1002 | - } |
|
| 1003 | - |
|
| 1004 | - public function runTask(Task $task): Task { |
|
| 1005 | - if (!$this->checkGuestAccess($task->getUserId())) { |
|
| 1006 | - throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); |
|
| 1007 | - } |
|
| 1008 | - if (!$this->canHandleTask($task)) { |
|
| 1009 | - throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); |
|
| 1010 | - } |
|
| 1011 | - |
|
| 1012 | - $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 1013 | - if ($provider instanceof ISynchronousProvider) { |
|
| 1014 | - $this->prepareTask($task); |
|
| 1015 | - $task->setStatus(Task::STATUS_SCHEDULED); |
|
| 1016 | - $this->storeTask($task); |
|
| 1017 | - $this->processTask($task, $provider); |
|
| 1018 | - $task = $this->getTask($task->getId()); |
|
| 1019 | - } else { |
|
| 1020 | - $this->scheduleTask($task); |
|
| 1021 | - // poll task |
|
| 1022 | - while ($task->getStatus() === Task::STATUS_SCHEDULED || $task->getStatus() === Task::STATUS_RUNNING) { |
|
| 1023 | - sleep(1); |
|
| 1024 | - $task = $this->getTask($task->getId()); |
|
| 1025 | - } |
|
| 1026 | - } |
|
| 1027 | - return $task; |
|
| 1028 | - } |
|
| 1029 | - |
|
| 1030 | - public function processTask(Task $task, ISynchronousProvider $provider): bool { |
|
| 1031 | - try { |
|
| 1032 | - try { |
|
| 1033 | - $input = $this->prepareInputData($task); |
|
| 1034 | - } catch (GenericFileException|NotPermittedException|LockedException|ValidationException|UnauthorizedException $e) { |
|
| 1035 | - $this->logger->warning('Failed to prepare input data for a TaskProcessing task with synchronous provider ' . $provider->getId(), ['exception' => $e]); |
|
| 1036 | - $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1037 | - return false; |
|
| 1038 | - } |
|
| 1039 | - try { |
|
| 1040 | - $this->setTaskStatus($task, Task::STATUS_RUNNING); |
|
| 1041 | - $output = $provider->process($task->getUserId(), $input, fn (float $progress) => $this->setTaskProgress($task->getId(), $progress)); |
|
| 1042 | - } catch (ProcessingException $e) { |
|
| 1043 | - $this->logger->warning('Failed to process a TaskProcessing task with synchronous provider ' . $provider->getId(), ['exception' => $e]); |
|
| 1044 | - $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1045 | - return false; |
|
| 1046 | - } catch (\Throwable $e) { |
|
| 1047 | - $this->logger->error('Unknown error while processing TaskProcessing task', ['exception' => $e]); |
|
| 1048 | - $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1049 | - return false; |
|
| 1050 | - } |
|
| 1051 | - $this->setTaskResult($task->getId(), null, $output); |
|
| 1052 | - } catch (NotFoundException $e) { |
|
| 1053 | - $this->logger->info('Could not find task anymore after execution. Moving on.', ['exception' => $e]); |
|
| 1054 | - } catch (Exception $e) { |
|
| 1055 | - $this->logger->error('Failed to report result of TaskProcessing task', ['exception' => $e]); |
|
| 1056 | - } |
|
| 1057 | - return true; |
|
| 1058 | - } |
|
| 1059 | - |
|
| 1060 | - public function deleteTask(Task $task): void { |
|
| 1061 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1062 | - $this->taskMapper->delete($taskEntity); |
|
| 1063 | - } |
|
| 1064 | - |
|
| 1065 | - public function getTask(int $id): Task { |
|
| 1066 | - try { |
|
| 1067 | - $taskEntity = $this->taskMapper->find($id); |
|
| 1068 | - return $taskEntity->toPublicTask(); |
|
| 1069 | - } catch (DoesNotExistException $e) { |
|
| 1070 | - throw new NotFoundException('Couldn\'t find task with id ' . $id, 0, $e); |
|
| 1071 | - } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) { |
|
| 1072 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1073 | - } catch (\JsonException $e) { |
|
| 1074 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e); |
|
| 1075 | - } |
|
| 1076 | - } |
|
| 1077 | - |
|
| 1078 | - public function cancelTask(int $id): void { |
|
| 1079 | - $task = $this->getTask($id); |
|
| 1080 | - if ($task->getStatus() !== Task::STATUS_SCHEDULED && $task->getStatus() !== Task::STATUS_RUNNING) { |
|
| 1081 | - return; |
|
| 1082 | - } |
|
| 1083 | - $task->setStatus(Task::STATUS_CANCELLED); |
|
| 1084 | - $task->setEndedAt(time()); |
|
| 1085 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1086 | - try { |
|
| 1087 | - $this->taskMapper->update($taskEntity); |
|
| 1088 | - $this->runWebhook($task); |
|
| 1089 | - } catch (\OCP\DB\Exception $e) { |
|
| 1090 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1091 | - } |
|
| 1092 | - } |
|
| 1093 | - |
|
| 1094 | - public function setTaskProgress(int $id, float $progress): bool { |
|
| 1095 | - // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently |
|
| 1096 | - $task = $this->getTask($id); |
|
| 1097 | - if ($task->getStatus() === Task::STATUS_CANCELLED) { |
|
| 1098 | - return false; |
|
| 1099 | - } |
|
| 1100 | - // only set the start time if the task is going from scheduled to running |
|
| 1101 | - if ($task->getstatus() === Task::STATUS_SCHEDULED) { |
|
| 1102 | - $task->setStartedAt(time()); |
|
| 1103 | - } |
|
| 1104 | - $task->setStatus(Task::STATUS_RUNNING); |
|
| 1105 | - $task->setProgress($progress); |
|
| 1106 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1107 | - try { |
|
| 1108 | - $this->taskMapper->update($taskEntity); |
|
| 1109 | - } catch (\OCP\DB\Exception $e) { |
|
| 1110 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1111 | - } |
|
| 1112 | - return true; |
|
| 1113 | - } |
|
| 1114 | - |
|
| 1115 | - public function setTaskResult(int $id, ?string $error, ?array $result, bool $isUsingFileIds = false): void { |
|
| 1116 | - // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently |
|
| 1117 | - $task = $this->getTask($id); |
|
| 1118 | - if ($task->getStatus() === Task::STATUS_CANCELLED) { |
|
| 1119 | - $this->logger->info('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.'); |
|
| 1120 | - return; |
|
| 1121 | - } |
|
| 1122 | - if ($error !== null) { |
|
| 1123 | - $task->setStatus(Task::STATUS_FAILED); |
|
| 1124 | - $task->setEndedAt(time()); |
|
| 1125 | - // truncate error message to 1000 characters |
|
| 1126 | - $task->setErrorMessage(mb_substr($error, 0, 1000)); |
|
| 1127 | - $this->logger->warning('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' failed with the following message: ' . $error); |
|
| 1128 | - } elseif ($result !== null) { |
|
| 1129 | - $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1130 | - $outputShape = $taskTypes[$task->getTaskTypeId()]['outputShape']; |
|
| 1131 | - $outputShapeEnumValues = $taskTypes[$task->getTaskTypeId()]['outputShapeEnumValues']; |
|
| 1132 | - $optionalOutputShape = $taskTypes[$task->getTaskTypeId()]['optionalOutputShape']; |
|
| 1133 | - $optionalOutputShapeEnumValues = $taskTypes[$task->getTaskTypeId()]['optionalOutputShapeEnumValues']; |
|
| 1134 | - try { |
|
| 1135 | - // validate output |
|
| 1136 | - if (!$isUsingFileIds) { |
|
| 1137 | - $this->validateOutputWithFileData($outputShape, $outputShapeEnumValues, $result); |
|
| 1138 | - $this->validateOutputWithFileData($optionalOutputShape, $optionalOutputShapeEnumValues, $result, true); |
|
| 1139 | - } else { |
|
| 1140 | - $this->validateOutputWithFileIds($outputShape, $outputShapeEnumValues, $result); |
|
| 1141 | - $this->validateOutputWithFileIds($optionalOutputShape, $optionalOutputShapeEnumValues, $result, true); |
|
| 1142 | - } |
|
| 1143 | - $output = $this->removeSuperfluousArrayKeys($result, $outputShape, $optionalOutputShape); |
|
| 1144 | - // extract raw data and put it in files, replace it with file ids |
|
| 1145 | - if (!$isUsingFileIds) { |
|
| 1146 | - $output = $this->encapsulateOutputFileData($output, $outputShape, $optionalOutputShape); |
|
| 1147 | - } else { |
|
| 1148 | - $this->validateOutputFileIds($output, $outputShape, $optionalOutputShape); |
|
| 1149 | - } |
|
| 1150 | - // Turn file objects into IDs |
|
| 1151 | - foreach ($output as $key => $value) { |
|
| 1152 | - if ($value instanceof Node) { |
|
| 1153 | - $output[$key] = $value->getId(); |
|
| 1154 | - } |
|
| 1155 | - if (is_array($value) && isset($value[0]) && $value[0] instanceof Node) { |
|
| 1156 | - $output[$key] = array_map(fn ($node) => $node->getId(), $value); |
|
| 1157 | - } |
|
| 1158 | - } |
|
| 1159 | - $task->setOutput($output); |
|
| 1160 | - $task->setProgress(1); |
|
| 1161 | - $task->setStatus(Task::STATUS_SUCCESSFUL); |
|
| 1162 | - $task->setEndedAt(time()); |
|
| 1163 | - } catch (ValidationException $e) { |
|
| 1164 | - $task->setProgress(1); |
|
| 1165 | - $task->setStatus(Task::STATUS_FAILED); |
|
| 1166 | - $task->setEndedAt(time()); |
|
| 1167 | - $error = 'The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec'; |
|
| 1168 | - $task->setErrorMessage($error); |
|
| 1169 | - $this->logger->error($error, ['exception' => $e, 'output' => $result]); |
|
| 1170 | - } catch (NotPermittedException $e) { |
|
| 1171 | - $task->setProgress(1); |
|
| 1172 | - $task->setStatus(Task::STATUS_FAILED); |
|
| 1173 | - $task->setEndedAt(time()); |
|
| 1174 | - $error = 'The task was processed successfully but storing the output in a file failed'; |
|
| 1175 | - $task->setErrorMessage($error); |
|
| 1176 | - $this->logger->error($error, ['exception' => $e]); |
|
| 1177 | - } catch (InvalidPathException|\OCP\Files\NotFoundException $e) { |
|
| 1178 | - $task->setProgress(1); |
|
| 1179 | - $task->setStatus(Task::STATUS_FAILED); |
|
| 1180 | - $task->setEndedAt(time()); |
|
| 1181 | - $error = 'The task was processed successfully but the result file could not be found'; |
|
| 1182 | - $task->setErrorMessage($error); |
|
| 1183 | - $this->logger->error($error, ['exception' => $e]); |
|
| 1184 | - } |
|
| 1185 | - } |
|
| 1186 | - try { |
|
| 1187 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1188 | - } catch (\JsonException $e) { |
|
| 1189 | - throw new \OCP\TaskProcessing\Exception\Exception('The task was processed successfully but the provider\'s output could not be encoded as JSON for the database.', 0, $e); |
|
| 1190 | - } |
|
| 1191 | - try { |
|
| 1192 | - $this->taskMapper->update($taskEntity); |
|
| 1193 | - $this->runWebhook($task); |
|
| 1194 | - } catch (\OCP\DB\Exception $e) { |
|
| 1195 | - throw new \OCP\TaskProcessing\Exception\Exception($e->getMessage()); |
|
| 1196 | - } |
|
| 1197 | - if ($task->getStatus() === Task::STATUS_SUCCESSFUL) { |
|
| 1198 | - $event = new TaskSuccessfulEvent($task); |
|
| 1199 | - } else { |
|
| 1200 | - $event = new TaskFailedEvent($task, $error); |
|
| 1201 | - } |
|
| 1202 | - $this->dispatcher->dispatchTyped($event); |
|
| 1203 | - } |
|
| 1204 | - |
|
| 1205 | - public function getNextScheduledTask(array $taskTypeIds = [], array $taskIdsToIgnore = []): Task { |
|
| 1206 | - try { |
|
| 1207 | - $taskEntity = $this->taskMapper->findOldestScheduledByType($taskTypeIds, $taskIdsToIgnore); |
|
| 1208 | - return $taskEntity->toPublicTask(); |
|
| 1209 | - } catch (DoesNotExistException $e) { |
|
| 1210 | - throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', previous: $e); |
|
| 1211 | - } catch (\OCP\DB\Exception $e) { |
|
| 1212 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', previous: $e); |
|
| 1213 | - } catch (\JsonException $e) { |
|
| 1214 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', previous: $e); |
|
| 1215 | - } |
|
| 1216 | - } |
|
| 1217 | - |
|
| 1218 | - public function getNextScheduledTasks(array $taskTypeIds = [], array $taskIdsToIgnore = [], int $numberOfTasks = 1): array { |
|
| 1219 | - try { |
|
| 1220 | - return array_map(fn ($taskEntity) => $taskEntity->toPublicTask(), $this->taskMapper->findNOldestScheduledByType($taskTypeIds, $taskIdsToIgnore, $numberOfTasks)); |
|
| 1221 | - } catch (DoesNotExistException $e) { |
|
| 1222 | - throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', previous: $e); |
|
| 1223 | - } catch (\OCP\DB\Exception $e) { |
|
| 1224 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', previous: $e); |
|
| 1225 | - } catch (\JsonException $e) { |
|
| 1226 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', previous: $e); |
|
| 1227 | - } |
|
| 1228 | - } |
|
| 1229 | - |
|
| 1230 | - /** |
|
| 1231 | - * Takes task input data and replaces fileIds with File objects |
|
| 1232 | - * |
|
| 1233 | - * @param string|null $userId |
|
| 1234 | - * @param array<array-key, list<numeric|string>|numeric|string> $input |
|
| 1235 | - * @param ShapeDescriptor[] ...$specs the specs |
|
| 1236 | - * @return array<array-key, list<File|numeric|string>|numeric|string|File> |
|
| 1237 | - * @throws GenericFileException|LockedException|NotPermittedException|ValidationException|UnauthorizedException |
|
| 1238 | - */ |
|
| 1239 | - public function fillInputFileData(?string $userId, array $input, ...$specs): array { |
|
| 1240 | - if ($userId !== null) { |
|
| 1241 | - \OC_Util::setupFS($userId); |
|
| 1242 | - } |
|
| 1243 | - $newInputOutput = []; |
|
| 1244 | - $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1245 | - foreach ($spec as $key => $descriptor) { |
|
| 1246 | - $type = $descriptor->getShapeType(); |
|
| 1247 | - if (!isset($input[$key])) { |
|
| 1248 | - continue; |
|
| 1249 | - } |
|
| 1250 | - if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1251 | - $newInputOutput[$key] = $input[$key]; |
|
| 1252 | - continue; |
|
| 1253 | - } |
|
| 1254 | - if (EShapeType::getScalarType($type) === $type) { |
|
| 1255 | - // is scalar |
|
| 1256 | - $node = $this->validateFileId((int)$input[$key]); |
|
| 1257 | - $this->validateUserAccessToFile($input[$key], $userId); |
|
| 1258 | - $newInputOutput[$key] = $node; |
|
| 1259 | - } else { |
|
| 1260 | - // is list |
|
| 1261 | - $newInputOutput[$key] = []; |
|
| 1262 | - foreach ($input[$key] as $item) { |
|
| 1263 | - $node = $this->validateFileId((int)$item); |
|
| 1264 | - $this->validateUserAccessToFile($item, $userId); |
|
| 1265 | - $newInputOutput[$key][] = $node; |
|
| 1266 | - } |
|
| 1267 | - } |
|
| 1268 | - } |
|
| 1269 | - return $newInputOutput; |
|
| 1270 | - } |
|
| 1271 | - |
|
| 1272 | - public function getUserTask(int $id, ?string $userId): Task { |
|
| 1273 | - try { |
|
| 1274 | - $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId); |
|
| 1275 | - return $taskEntity->toPublicTask(); |
|
| 1276 | - } catch (DoesNotExistException $e) { |
|
| 1277 | - throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $e); |
|
| 1278 | - } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) { |
|
| 1279 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1280 | - } catch (\JsonException $e) { |
|
| 1281 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e); |
|
| 1282 | - } |
|
| 1283 | - } |
|
| 1284 | - |
|
| 1285 | - public function getUserTasks(?string $userId, ?string $taskTypeId = null, ?string $customId = null): array { |
|
| 1286 | - try { |
|
| 1287 | - $taskEntities = $this->taskMapper->findByUserAndTaskType($userId, $taskTypeId, $customId); |
|
| 1288 | - return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1289 | - } catch (\OCP\DB\Exception $e) { |
|
| 1290 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $e); |
|
| 1291 | - } catch (\JsonException $e) { |
|
| 1292 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $e); |
|
| 1293 | - } |
|
| 1294 | - } |
|
| 1295 | - |
|
| 1296 | - public function getTasks( |
|
| 1297 | - ?string $userId, ?string $taskTypeId = null, ?string $appId = null, ?string $customId = null, |
|
| 1298 | - ?int $status = null, ?int $scheduleAfter = null, ?int $endedBefore = null, |
|
| 1299 | - ): array { |
|
| 1300 | - try { |
|
| 1301 | - $taskEntities = $this->taskMapper->findTasks($userId, $taskTypeId, $appId, $customId, $status, $scheduleAfter, $endedBefore); |
|
| 1302 | - return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1303 | - } catch (\OCP\DB\Exception $e) { |
|
| 1304 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $e); |
|
| 1305 | - } catch (\JsonException $e) { |
|
| 1306 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $e); |
|
| 1307 | - } |
|
| 1308 | - } |
|
| 1309 | - |
|
| 1310 | - public function getUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array { |
|
| 1311 | - try { |
|
| 1312 | - $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $customId); |
|
| 1313 | - return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1314 | - } catch (\OCP\DB\Exception $e) { |
|
| 1315 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding a task', 0, $e); |
|
| 1316 | - } catch (\JsonException $e) { |
|
| 1317 | - throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding a task', 0, $e); |
|
| 1318 | - } |
|
| 1319 | - } |
|
| 1320 | - |
|
| 1321 | - /** |
|
| 1322 | - *Takes task input or output and replaces base64 data with file ids |
|
| 1323 | - * |
|
| 1324 | - * @param array $output |
|
| 1325 | - * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 1326 | - * @return array |
|
| 1327 | - * @throws NotPermittedException |
|
| 1328 | - */ |
|
| 1329 | - public function encapsulateOutputFileData(array $output, ...$specs): array { |
|
| 1330 | - $newOutput = []; |
|
| 1331 | - try { |
|
| 1332 | - $folder = $this->appData->getFolder('TaskProcessing'); |
|
| 1333 | - } catch (\OCP\Files\NotFoundException) { |
|
| 1334 | - $folder = $this->appData->newFolder('TaskProcessing'); |
|
| 1335 | - } |
|
| 1336 | - $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1337 | - foreach ($spec as $key => $descriptor) { |
|
| 1338 | - $type = $descriptor->getShapeType(); |
|
| 1339 | - if (!isset($output[$key])) { |
|
| 1340 | - continue; |
|
| 1341 | - } |
|
| 1342 | - if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1343 | - $newOutput[$key] = $output[$key]; |
|
| 1344 | - continue; |
|
| 1345 | - } |
|
| 1346 | - if (EShapeType::getScalarType($type) === $type) { |
|
| 1347 | - /** @var SimpleFile $file */ |
|
| 1348 | - $file = $folder->newFile(time() . '-' . rand(1, 100000), $output[$key]); |
|
| 1349 | - $newOutput[$key] = $file->getId(); // polymorphic call to SimpleFile |
|
| 1350 | - } else { |
|
| 1351 | - $newOutput = []; |
|
| 1352 | - foreach ($output[$key] as $item) { |
|
| 1353 | - /** @var SimpleFile $file */ |
|
| 1354 | - $file = $folder->newFile(time() . '-' . rand(1, 100000), $item); |
|
| 1355 | - $newOutput[$key][] = $file->getId(); |
|
| 1356 | - } |
|
| 1357 | - } |
|
| 1358 | - } |
|
| 1359 | - return $newOutput; |
|
| 1360 | - } |
|
| 1361 | - |
|
| 1362 | - /** |
|
| 1363 | - * @param Task $task |
|
| 1364 | - * @return array<array-key, list<numeric|string|File>|numeric|string|File> |
|
| 1365 | - * @throws GenericFileException |
|
| 1366 | - * @throws LockedException |
|
| 1367 | - * @throws NotPermittedException |
|
| 1368 | - * @throws ValidationException|UnauthorizedException |
|
| 1369 | - */ |
|
| 1370 | - public function prepareInputData(Task $task): array { |
|
| 1371 | - $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1372 | - $inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape']; |
|
| 1373 | - $optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape']; |
|
| 1374 | - $input = $task->getInput(); |
|
| 1375 | - $input = $this->removeSuperfluousArrayKeys($input, $inputShape, $optionalInputShape); |
|
| 1376 | - $input = $this->fillInputFileData($task->getUserId(), $input, $inputShape, $optionalInputShape); |
|
| 1377 | - return $input; |
|
| 1378 | - } |
|
| 1379 | - |
|
| 1380 | - public function lockTask(Task $task): bool { |
|
| 1381 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1382 | - if ($this->taskMapper->lockTask($taskEntity) === 0) { |
|
| 1383 | - return false; |
|
| 1384 | - } |
|
| 1385 | - $task->setStatus(Task::STATUS_RUNNING); |
|
| 1386 | - return true; |
|
| 1387 | - } |
|
| 1388 | - |
|
| 1389 | - /** |
|
| 1390 | - * @throws \JsonException |
|
| 1391 | - * @throws Exception |
|
| 1392 | - */ |
|
| 1393 | - public function setTaskStatus(Task $task, int $status): void { |
|
| 1394 | - $currentTaskStatus = $task->getStatus(); |
|
| 1395 | - if ($currentTaskStatus === Task::STATUS_SCHEDULED && $status === Task::STATUS_RUNNING) { |
|
| 1396 | - $task->setStartedAt(time()); |
|
| 1397 | - } elseif ($currentTaskStatus === Task::STATUS_RUNNING && ($status === Task::STATUS_FAILED || $status === Task::STATUS_CANCELLED)) { |
|
| 1398 | - $task->setEndedAt(time()); |
|
| 1399 | - } elseif ($currentTaskStatus === Task::STATUS_UNKNOWN && $status === Task::STATUS_SCHEDULED) { |
|
| 1400 | - $task->setScheduledAt(time()); |
|
| 1401 | - } |
|
| 1402 | - $task->setStatus($status); |
|
| 1403 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1404 | - $this->taskMapper->update($taskEntity); |
|
| 1405 | - } |
|
| 1406 | - |
|
| 1407 | - /** |
|
| 1408 | - * Validate input, fill input default values, set completionExpectedAt, set scheduledAt |
|
| 1409 | - * |
|
| 1410 | - * @param Task $task |
|
| 1411 | - * @return void |
|
| 1412 | - * @throws UnauthorizedException |
|
| 1413 | - * @throws ValidationException |
|
| 1414 | - * @throws \OCP\TaskProcessing\Exception\Exception |
|
| 1415 | - */ |
|
| 1416 | - private function prepareTask(Task $task): void { |
|
| 1417 | - $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1418 | - $taskType = $taskTypes[$task->getTaskTypeId()]; |
|
| 1419 | - $inputShape = $taskType['inputShape']; |
|
| 1420 | - $inputShapeDefaults = $taskType['inputShapeDefaults']; |
|
| 1421 | - $inputShapeEnumValues = $taskType['inputShapeEnumValues']; |
|
| 1422 | - $optionalInputShape = $taskType['optionalInputShape']; |
|
| 1423 | - $optionalInputShapeEnumValues = $taskType['optionalInputShapeEnumValues']; |
|
| 1424 | - $optionalInputShapeDefaults = $taskType['optionalInputShapeDefaults']; |
|
| 1425 | - // validate input |
|
| 1426 | - $this->validateInput($inputShape, $inputShapeDefaults, $inputShapeEnumValues, $task->getInput()); |
|
| 1427 | - $this->validateInput($optionalInputShape, $optionalInputShapeDefaults, $optionalInputShapeEnumValues, $task->getInput(), true); |
|
| 1428 | - // authenticate access to mentioned files |
|
| 1429 | - $ids = []; |
|
| 1430 | - foreach ($inputShape + $optionalInputShape as $key => $descriptor) { |
|
| 1431 | - if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1432 | - /** @var list<int>|int $inputSlot */ |
|
| 1433 | - $inputSlot = $task->getInput()[$key]; |
|
| 1434 | - if (is_array($inputSlot)) { |
|
| 1435 | - $ids += $inputSlot; |
|
| 1436 | - } else { |
|
| 1437 | - $ids[] = $inputSlot; |
|
| 1438 | - } |
|
| 1439 | - } |
|
| 1440 | - } |
|
| 1441 | - foreach ($ids as $fileId) { |
|
| 1442 | - $this->validateFileId($fileId); |
|
| 1443 | - $this->validateUserAccessToFile($fileId, $task->getUserId()); |
|
| 1444 | - } |
|
| 1445 | - // remove superfluous keys and set input |
|
| 1446 | - $input = $this->removeSuperfluousArrayKeys($task->getInput(), $inputShape, $optionalInputShape); |
|
| 1447 | - $inputWithDefaults = $this->fillInputDefaults($input, $inputShapeDefaults, $optionalInputShapeDefaults); |
|
| 1448 | - $task->setInput($inputWithDefaults); |
|
| 1449 | - $task->setScheduledAt(time()); |
|
| 1450 | - $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 1451 | - // calculate expected completion time |
|
| 1452 | - $completionExpectedAt = new \DateTime('now'); |
|
| 1453 | - $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S')); |
|
| 1454 | - $task->setCompletionExpectedAt($completionExpectedAt); |
|
| 1455 | - } |
|
| 1456 | - |
|
| 1457 | - /** |
|
| 1458 | - * Store the task in the DB and set its ID in the \OCP\TaskProcessing\Task input param |
|
| 1459 | - * |
|
| 1460 | - * @param Task $task |
|
| 1461 | - * @return void |
|
| 1462 | - * @throws Exception |
|
| 1463 | - * @throws \JsonException |
|
| 1464 | - */ |
|
| 1465 | - private function storeTask(Task $task): void { |
|
| 1466 | - // create a db entity and insert into db table |
|
| 1467 | - $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1468 | - $this->taskMapper->insert($taskEntity); |
|
| 1469 | - // make sure the scheduler knows the id |
|
| 1470 | - $task->setId($taskEntity->getId()); |
|
| 1471 | - } |
|
| 1472 | - |
|
| 1473 | - /** |
|
| 1474 | - * @param array $output |
|
| 1475 | - * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 1476 | - * @return array |
|
| 1477 | - * @throws NotPermittedException |
|
| 1478 | - */ |
|
| 1479 | - private function validateOutputFileIds(array $output, ...$specs): array { |
|
| 1480 | - $newOutput = []; |
|
| 1481 | - $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1482 | - foreach ($spec as $key => $descriptor) { |
|
| 1483 | - $type = $descriptor->getShapeType(); |
|
| 1484 | - if (!isset($output[$key])) { |
|
| 1485 | - continue; |
|
| 1486 | - } |
|
| 1487 | - if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1488 | - $newOutput[$key] = $output[$key]; |
|
| 1489 | - continue; |
|
| 1490 | - } |
|
| 1491 | - if (EShapeType::getScalarType($type) === $type) { |
|
| 1492 | - // Is scalar file ID |
|
| 1493 | - $newOutput[$key] = $this->validateFileId($output[$key]); |
|
| 1494 | - } else { |
|
| 1495 | - // Is list of file IDs |
|
| 1496 | - $newOutput = []; |
|
| 1497 | - foreach ($output[$key] as $item) { |
|
| 1498 | - $newOutput[$key][] = $this->validateFileId($item); |
|
| 1499 | - } |
|
| 1500 | - } |
|
| 1501 | - } |
|
| 1502 | - return $newOutput; |
|
| 1503 | - } |
|
| 1504 | - |
|
| 1505 | - /** |
|
| 1506 | - * @param mixed $id |
|
| 1507 | - * @return File |
|
| 1508 | - * @throws ValidationException |
|
| 1509 | - */ |
|
| 1510 | - private function validateFileId(mixed $id): File { |
|
| 1511 | - $node = $this->rootFolder->getFirstNodeById($id); |
|
| 1512 | - if ($node === null) { |
|
| 1513 | - $node = $this->rootFolder->getFirstNodeByIdInPath($id, '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 1514 | - if ($node === null) { |
|
| 1515 | - throw new ValidationException('Could not find file ' . $id); |
|
| 1516 | - } elseif (!$node instanceof File) { |
|
| 1517 | - throw new ValidationException('File with id "' . $id . '" is not a file'); |
|
| 1518 | - } |
|
| 1519 | - } elseif (!$node instanceof File) { |
|
| 1520 | - throw new ValidationException('File with id "' . $id . '" is not a file'); |
|
| 1521 | - } |
|
| 1522 | - return $node; |
|
| 1523 | - } |
|
| 1524 | - |
|
| 1525 | - /** |
|
| 1526 | - * @param mixed $fileId |
|
| 1527 | - * @param string|null $userId |
|
| 1528 | - * @return void |
|
| 1529 | - * @throws UnauthorizedException |
|
| 1530 | - */ |
|
| 1531 | - private function validateUserAccessToFile(mixed $fileId, ?string $userId): void { |
|
| 1532 | - if ($userId === null) { |
|
| 1533 | - throw new UnauthorizedException('User does not have access to file ' . $fileId); |
|
| 1534 | - } |
|
| 1535 | - $mounts = $this->userMountCache->getMountsForFileId($fileId); |
|
| 1536 | - $userIds = array_map(fn ($mount) => $mount->getUser()->getUID(), $mounts); |
|
| 1537 | - if (!in_array($userId, $userIds)) { |
|
| 1538 | - throw new UnauthorizedException('User ' . $userId . ' does not have access to file ' . $fileId); |
|
| 1539 | - } |
|
| 1540 | - } |
|
| 1541 | - |
|
| 1542 | - /** |
|
| 1543 | - * @param Task $task |
|
| 1544 | - * @return list<int> |
|
| 1545 | - * @throws NotFoundException |
|
| 1546 | - */ |
|
| 1547 | - public function extractFileIdsFromTask(Task $task): array { |
|
| 1548 | - $ids = []; |
|
| 1549 | - $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1550 | - if (!isset($taskTypes[$task->getTaskTypeId()])) { |
|
| 1551 | - throw new NotFoundException('Could not find task type'); |
|
| 1552 | - } |
|
| 1553 | - $taskType = $taskTypes[$task->getTaskTypeId()]; |
|
| 1554 | - foreach ($taskType['inputShape'] + $taskType['optionalInputShape'] as $key => $descriptor) { |
|
| 1555 | - if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1556 | - /** @var int|list<int> $inputSlot */ |
|
| 1557 | - $inputSlot = $task->getInput()[$key]; |
|
| 1558 | - if (is_array($inputSlot)) { |
|
| 1559 | - $ids = array_merge($inputSlot, $ids); |
|
| 1560 | - } else { |
|
| 1561 | - $ids[] = $inputSlot; |
|
| 1562 | - } |
|
| 1563 | - } |
|
| 1564 | - } |
|
| 1565 | - if ($task->getOutput() !== null) { |
|
| 1566 | - foreach ($taskType['outputShape'] + $taskType['optionalOutputShape'] as $key => $descriptor) { |
|
| 1567 | - if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1568 | - /** @var int|list<int> $outputSlot */ |
|
| 1569 | - $outputSlot = $task->getOutput()[$key]; |
|
| 1570 | - if (is_array($outputSlot)) { |
|
| 1571 | - $ids = array_merge($outputSlot, $ids); |
|
| 1572 | - } else { |
|
| 1573 | - $ids[] = $outputSlot; |
|
| 1574 | - } |
|
| 1575 | - } |
|
| 1576 | - } |
|
| 1577 | - } |
|
| 1578 | - return $ids; |
|
| 1579 | - } |
|
| 1580 | - |
|
| 1581 | - /** |
|
| 1582 | - * @param ISimpleFolder $folder |
|
| 1583 | - * @param int $ageInSeconds |
|
| 1584 | - * @return \Generator |
|
| 1585 | - */ |
|
| 1586 | - public function clearFilesOlderThan(ISimpleFolder $folder, int $ageInSeconds = self::MAX_TASK_AGE_SECONDS): \Generator { |
|
| 1587 | - foreach ($folder->getDirectoryListing() as $file) { |
|
| 1588 | - if ($file->getMTime() < time() - $ageInSeconds) { |
|
| 1589 | - try { |
|
| 1590 | - $fileName = $file->getName(); |
|
| 1591 | - $file->delete(); |
|
| 1592 | - yield $fileName; |
|
| 1593 | - } catch (NotPermittedException $e) { |
|
| 1594 | - $this->logger->warning('Failed to delete a stale task processing file', ['exception' => $e]); |
|
| 1595 | - } |
|
| 1596 | - } |
|
| 1597 | - } |
|
| 1598 | - } |
|
| 1599 | - |
|
| 1600 | - /** |
|
| 1601 | - * @param int $ageInSeconds |
|
| 1602 | - * @return \Generator |
|
| 1603 | - * @throws Exception |
|
| 1604 | - * @throws InvalidPathException |
|
| 1605 | - * @throws NotFoundException |
|
| 1606 | - * @throws \JsonException |
|
| 1607 | - * @throws \OCP\Files\NotFoundException |
|
| 1608 | - */ |
|
| 1609 | - public function cleanupTaskProcessingTaskFiles(int $ageInSeconds = self::MAX_TASK_AGE_SECONDS): \Generator { |
|
| 1610 | - $taskIdsToCleanup = []; |
|
| 1611 | - foreach ($this->taskMapper->getTasksToCleanup($ageInSeconds) as $task) { |
|
| 1612 | - $taskIdsToCleanup[] = $task->getId(); |
|
| 1613 | - $ocpTask = $task->toPublicTask(); |
|
| 1614 | - $fileIds = $this->extractFileIdsFromTask($ocpTask); |
|
| 1615 | - foreach ($fileIds as $fileId) { |
|
| 1616 | - // only look for output files stored in appData/TaskProcessing/ |
|
| 1617 | - $file = $this->rootFolder->getFirstNodeByIdInPath($fileId, '/' . $this->rootFolder->getAppDataDirectoryName() . '/core/TaskProcessing/'); |
|
| 1618 | - if ($file instanceof File) { |
|
| 1619 | - try { |
|
| 1620 | - $fileId = $file->getId(); |
|
| 1621 | - $fileName = $file->getName(); |
|
| 1622 | - $file->delete(); |
|
| 1623 | - yield ['task_id' => $task->getId(), 'file_id' => $fileId, 'file_name' => $fileName]; |
|
| 1624 | - } catch (NotPermittedException $e) { |
|
| 1625 | - $this->logger->warning('Failed to delete a stale task processing file', ['exception' => $e]); |
|
| 1626 | - } |
|
| 1627 | - } |
|
| 1628 | - } |
|
| 1629 | - } |
|
| 1630 | - return $taskIdsToCleanup; |
|
| 1631 | - } |
|
| 1632 | - |
|
| 1633 | - /** |
|
| 1634 | - * Make a request to the task's webhookUri if necessary |
|
| 1635 | - * |
|
| 1636 | - * @param Task $task |
|
| 1637 | - */ |
|
| 1638 | - private function runWebhook(Task $task): void { |
|
| 1639 | - $uri = $task->getWebhookUri(); |
|
| 1640 | - $method = $task->getWebhookMethod(); |
|
| 1641 | - |
|
| 1642 | - if (!$uri || !$method) { |
|
| 1643 | - return; |
|
| 1644 | - } |
|
| 1645 | - |
|
| 1646 | - if (in_array($method, ['HTTP:GET', 'HTTP:POST', 'HTTP:PUT', 'HTTP:DELETE'], true)) { |
|
| 1647 | - $client = $this->clientService->newClient(); |
|
| 1648 | - $httpMethod = preg_replace('/^HTTP:/', '', $method); |
|
| 1649 | - $options = [ |
|
| 1650 | - 'timeout' => 30, |
|
| 1651 | - 'body' => json_encode([ |
|
| 1652 | - 'task' => $task->jsonSerialize(), |
|
| 1653 | - ]), |
|
| 1654 | - 'headers' => ['Content-Type' => 'application/json'], |
|
| 1655 | - ]; |
|
| 1656 | - try { |
|
| 1657 | - $client->request($httpMethod, $uri, $options); |
|
| 1658 | - } catch (ClientException|ServerException $e) { |
|
| 1659 | - $this->logger->warning('Task processing HTTP webhook failed for task ' . $task->getId() . '. Request failed', ['exception' => $e]); |
|
| 1660 | - } catch (\Exception|\Throwable $e) { |
|
| 1661 | - $this->logger->warning('Task processing HTTP webhook failed for task ' . $task->getId() . '. Unknown error', ['exception' => $e]); |
|
| 1662 | - } |
|
| 1663 | - } elseif (str_starts_with($method, 'AppAPI:') && str_starts_with($uri, '/')) { |
|
| 1664 | - $parsedMethod = explode(':', $method, 4); |
|
| 1665 | - if (count($parsedMethod) < 3) { |
|
| 1666 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Invalid method: ' . $method); |
|
| 1667 | - } |
|
| 1668 | - [, $exAppId, $httpMethod] = $parsedMethod; |
|
| 1669 | - if (!$this->appManager->isEnabledForAnyone('app_api')) { |
|
| 1670 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. AppAPI is disabled or not installed.'); |
|
| 1671 | - return; |
|
| 1672 | - } |
|
| 1673 | - try { |
|
| 1674 | - $appApiFunctions = \OCP\Server::get(\OCA\AppAPI\PublicFunctions::class); |
|
| 1675 | - } catch (ContainerExceptionInterface|NotFoundExceptionInterface) { |
|
| 1676 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Could not get AppAPI public functions.'); |
|
| 1677 | - return; |
|
| 1678 | - } |
|
| 1679 | - $exApp = $appApiFunctions->getExApp($exAppId); |
|
| 1680 | - if ($exApp === null) { |
|
| 1681 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. ExApp ' . $exAppId . ' is missing.'); |
|
| 1682 | - return; |
|
| 1683 | - } elseif (!$exApp['enabled']) { |
|
| 1684 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. ExApp ' . $exAppId . ' is disabled.'); |
|
| 1685 | - return; |
|
| 1686 | - } |
|
| 1687 | - $requestParams = [ |
|
| 1688 | - 'task' => $task->jsonSerialize(), |
|
| 1689 | - ]; |
|
| 1690 | - $requestOptions = [ |
|
| 1691 | - 'timeout' => 30, |
|
| 1692 | - ]; |
|
| 1693 | - $response = $appApiFunctions->exAppRequest($exAppId, $uri, $task->getUserId(), $httpMethod, $requestParams, $requestOptions); |
|
| 1694 | - if (is_array($response) && isset($response['error'])) { |
|
| 1695 | - $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Error during request to ExApp(' . $exAppId . '): ', $response['error']); |
|
| 1696 | - } |
|
| 1697 | - } |
|
| 1698 | - } |
|
| 75 | + public const LEGACY_PREFIX_TEXTPROCESSING = 'legacy:TextProcessing:'; |
|
| 76 | + public const LEGACY_PREFIX_TEXTTOIMAGE = 'legacy:TextToImage:'; |
|
| 77 | + public const LEGACY_PREFIX_SPEECHTOTEXT = 'legacy:SpeechToText:'; |
|
| 78 | + |
|
| 79 | + public const LAZY_CONFIG_KEYS = [ |
|
| 80 | + 'ai.taskprocessing_type_preferences', |
|
| 81 | + 'ai.taskprocessing_provider_preferences', |
|
| 82 | + ]; |
|
| 83 | + |
|
| 84 | + public const MAX_TASK_AGE_SECONDS = 60 * 60 * 24 * 31 * 6; // 6 months |
|
| 85 | + |
|
| 86 | + private const TASK_TYPES_CACHE_KEY = 'available_task_types_v3'; |
|
| 87 | + private const TASK_TYPE_IDS_CACHE_KEY = 'available_task_type_ids'; |
|
| 88 | + |
|
| 89 | + /** @var list<IProvider>|null */ |
|
| 90 | + private ?array $providers = null; |
|
| 91 | + |
|
| 92 | + /** |
|
| 93 | + * @var array<array-key,array{name: string, description: string, inputShape: ShapeDescriptor[], inputShapeEnumValues: ShapeEnumValue[][], inputShapeDefaults: array<array-key, numeric|string>, isInternal: bool, optionalInputShape: ShapeDescriptor[], optionalInputShapeEnumValues: ShapeEnumValue[][], optionalInputShapeDefaults: array<array-key, numeric|string>, outputShape: ShapeDescriptor[], outputShapeEnumValues: ShapeEnumValue[][], optionalOutputShape: ShapeDescriptor[], optionalOutputShapeEnumValues: ShapeEnumValue[][]}> |
|
| 94 | + */ |
|
| 95 | + private ?array $availableTaskTypes = null; |
|
| 96 | + |
|
| 97 | + /** @var list<string>|null */ |
|
| 98 | + private ?array $availableTaskTypeIds = null; |
|
| 99 | + |
|
| 100 | + private IAppData $appData; |
|
| 101 | + private ?array $preferences = null; |
|
| 102 | + private ?array $providersById = null; |
|
| 103 | + |
|
| 104 | + /** @var ITaskType[]|null */ |
|
| 105 | + private ?array $taskTypes = null; |
|
| 106 | + private ICache $distributedCache; |
|
| 107 | + |
|
| 108 | + private ?GetTaskProcessingProvidersEvent $eventResult = null; |
|
| 109 | + |
|
| 110 | + public function __construct( |
|
| 111 | + private IAppConfig $appConfig, |
|
| 112 | + private Coordinator $coordinator, |
|
| 113 | + private IServerContainer $serverContainer, |
|
| 114 | + private LoggerInterface $logger, |
|
| 115 | + private TaskMapper $taskMapper, |
|
| 116 | + private IJobList $jobList, |
|
| 117 | + private IEventDispatcher $dispatcher, |
|
| 118 | + IAppDataFactory $appDataFactory, |
|
| 119 | + private IRootFolder $rootFolder, |
|
| 120 | + private \OCP\TextToImage\IManager $textToImageManager, |
|
| 121 | + private IUserMountCache $userMountCache, |
|
| 122 | + private IClientService $clientService, |
|
| 123 | + private IAppManager $appManager, |
|
| 124 | + private IUserManager $userManager, |
|
| 125 | + private IUserSession $userSession, |
|
| 126 | + ICacheFactory $cacheFactory, |
|
| 127 | + private IFactory $l10nFactory, |
|
| 128 | + ) { |
|
| 129 | + $this->appData = $appDataFactory->get('core'); |
|
| 130 | + $this->distributedCache = $cacheFactory->createDistributed('task_processing::'); |
|
| 131 | + } |
|
| 132 | + |
|
| 133 | + |
|
| 134 | + /** |
|
| 135 | + * This is almost a copy of textProcessingManager->getProviders |
|
| 136 | + * to avoid a dependency cycle between TextProcessingManager and TaskProcessingManager |
|
| 137 | + */ |
|
| 138 | + private function _getRawTextProcessingProviders(): array { |
|
| 139 | + $context = $this->coordinator->getRegistrationContext(); |
|
| 140 | + if ($context === null) { |
|
| 141 | + return []; |
|
| 142 | + } |
|
| 143 | + |
|
| 144 | + $providers = []; |
|
| 145 | + |
|
| 146 | + foreach ($context->getTextProcessingProviders() as $providerServiceRegistration) { |
|
| 147 | + $class = $providerServiceRegistration->getService(); |
|
| 148 | + try { |
|
| 149 | + $providers[$class] = $this->serverContainer->get($class); |
|
| 150 | + } catch (\Throwable $e) { |
|
| 151 | + $this->logger->error('Failed to load Text processing provider ' . $class, [ |
|
| 152 | + 'exception' => $e, |
|
| 153 | + ]); |
|
| 154 | + } |
|
| 155 | + } |
|
| 156 | + |
|
| 157 | + return $providers; |
|
| 158 | + } |
|
| 159 | + |
|
| 160 | + private function _getTextProcessingProviders(): array { |
|
| 161 | + $oldProviders = $this->_getRawTextProcessingProviders(); |
|
| 162 | + $newProviders = []; |
|
| 163 | + foreach ($oldProviders as $oldProvider) { |
|
| 164 | + $provider = new class($oldProvider) implements IProvider, ISynchronousProvider { |
|
| 165 | + private \OCP\TextProcessing\IProvider $provider; |
|
| 166 | + |
|
| 167 | + public function __construct(\OCP\TextProcessing\IProvider $provider) { |
|
| 168 | + $this->provider = $provider; |
|
| 169 | + } |
|
| 170 | + |
|
| 171 | + public function getId(): string { |
|
| 172 | + if ($this->provider instanceof \OCP\TextProcessing\IProviderWithId) { |
|
| 173 | + return $this->provider->getId(); |
|
| 174 | + } |
|
| 175 | + return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider::class; |
|
| 176 | + } |
|
| 177 | + |
|
| 178 | + public function getName(): string { |
|
| 179 | + return $this->provider->getName(); |
|
| 180 | + } |
|
| 181 | + |
|
| 182 | + public function getTaskTypeId(): string { |
|
| 183 | + return match ($this->provider->getTaskType()) { |
|
| 184 | + \OCP\TextProcessing\FreePromptTaskType::class => TextToText::ID, |
|
| 185 | + \OCP\TextProcessing\HeadlineTaskType::class => TextToTextHeadline::ID, |
|
| 186 | + \OCP\TextProcessing\TopicsTaskType::class => TextToTextTopics::ID, |
|
| 187 | + \OCP\TextProcessing\SummaryTaskType::class => TextToTextSummary::ID, |
|
| 188 | + default => Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->provider->getTaskType(), |
|
| 189 | + }; |
|
| 190 | + } |
|
| 191 | + |
|
| 192 | + public function getExpectedRuntime(): int { |
|
| 193 | + if ($this->provider instanceof \OCP\TextProcessing\IProviderWithExpectedRuntime) { |
|
| 194 | + return $this->provider->getExpectedRuntime(); |
|
| 195 | + } |
|
| 196 | + return 60; |
|
| 197 | + } |
|
| 198 | + |
|
| 199 | + public function getOptionalInputShape(): array { |
|
| 200 | + return []; |
|
| 201 | + } |
|
| 202 | + |
|
| 203 | + public function getOptionalOutputShape(): array { |
|
| 204 | + return []; |
|
| 205 | + } |
|
| 206 | + |
|
| 207 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 208 | + if ($this->provider instanceof \OCP\TextProcessing\IProviderWithUserId) { |
|
| 209 | + $this->provider->setUserId($userId); |
|
| 210 | + } |
|
| 211 | + try { |
|
| 212 | + return ['output' => $this->provider->process($input['input'])]; |
|
| 213 | + } catch (\RuntimeException $e) { |
|
| 214 | + throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 215 | + } |
|
| 216 | + } |
|
| 217 | + |
|
| 218 | + public function getInputShapeEnumValues(): array { |
|
| 219 | + return []; |
|
| 220 | + } |
|
| 221 | + |
|
| 222 | + public function getInputShapeDefaults(): array { |
|
| 223 | + return []; |
|
| 224 | + } |
|
| 225 | + |
|
| 226 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 227 | + return []; |
|
| 228 | + } |
|
| 229 | + |
|
| 230 | + public function getOptionalInputShapeDefaults(): array { |
|
| 231 | + return []; |
|
| 232 | + } |
|
| 233 | + |
|
| 234 | + public function getOutputShapeEnumValues(): array { |
|
| 235 | + return []; |
|
| 236 | + } |
|
| 237 | + |
|
| 238 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 239 | + return []; |
|
| 240 | + } |
|
| 241 | + }; |
|
| 242 | + $newProviders[$provider->getId()] = $provider; |
|
| 243 | + } |
|
| 244 | + |
|
| 245 | + return $newProviders; |
|
| 246 | + } |
|
| 247 | + |
|
| 248 | + /** |
|
| 249 | + * @return ITaskType[] |
|
| 250 | + */ |
|
| 251 | + private function _getTextProcessingTaskTypes(): array { |
|
| 252 | + $oldProviders = $this->_getRawTextProcessingProviders(); |
|
| 253 | + $newTaskTypes = []; |
|
| 254 | + foreach ($oldProviders as $oldProvider) { |
|
| 255 | + // These are already implemented in the TaskProcessing realm |
|
| 256 | + if (in_array($oldProvider->getTaskType(), [ |
|
| 257 | + \OCP\TextProcessing\FreePromptTaskType::class, |
|
| 258 | + \OCP\TextProcessing\HeadlineTaskType::class, |
|
| 259 | + \OCP\TextProcessing\TopicsTaskType::class, |
|
| 260 | + \OCP\TextProcessing\SummaryTaskType::class |
|
| 261 | + ], true)) { |
|
| 262 | + continue; |
|
| 263 | + } |
|
| 264 | + $taskType = new class($oldProvider->getTaskType()) implements ITaskType { |
|
| 265 | + private string $oldTaskTypeClass; |
|
| 266 | + private \OCP\TextProcessing\ITaskType $oldTaskType; |
|
| 267 | + private IL10N $l; |
|
| 268 | + |
|
| 269 | + public function __construct(string $oldTaskTypeClass) { |
|
| 270 | + $this->oldTaskTypeClass = $oldTaskTypeClass; |
|
| 271 | + $this->oldTaskType = \OCP\Server::get($oldTaskTypeClass); |
|
| 272 | + $this->l = \OCP\Server::get(IFactory::class)->get('core'); |
|
| 273 | + } |
|
| 274 | + |
|
| 275 | + public function getId(): string { |
|
| 276 | + return Manager::LEGACY_PREFIX_TEXTPROCESSING . $this->oldTaskTypeClass; |
|
| 277 | + } |
|
| 278 | + |
|
| 279 | + public function getName(): string { |
|
| 280 | + return $this->oldTaskType->getName(); |
|
| 281 | + } |
|
| 282 | + |
|
| 283 | + public function getDescription(): string { |
|
| 284 | + return $this->oldTaskType->getDescription(); |
|
| 285 | + } |
|
| 286 | + |
|
| 287 | + public function getInputShape(): array { |
|
| 288 | + return ['input' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)]; |
|
| 289 | + } |
|
| 290 | + |
|
| 291 | + public function getOutputShape(): array { |
|
| 292 | + return ['output' => new ShapeDescriptor($this->l->t('Input text'), $this->l->t('The input text'), EShapeType::Text)]; |
|
| 293 | + } |
|
| 294 | + }; |
|
| 295 | + $newTaskTypes[$taskType->getId()] = $taskType; |
|
| 296 | + } |
|
| 297 | + |
|
| 298 | + return $newTaskTypes; |
|
| 299 | + } |
|
| 300 | + |
|
| 301 | + /** |
|
| 302 | + * @return IProvider[] |
|
| 303 | + */ |
|
| 304 | + private function _getTextToImageProviders(): array { |
|
| 305 | + $oldProviders = $this->textToImageManager->getProviders(); |
|
| 306 | + $newProviders = []; |
|
| 307 | + foreach ($oldProviders as $oldProvider) { |
|
| 308 | + $newProvider = new class($oldProvider, $this->appData) implements IProvider, ISynchronousProvider { |
|
| 309 | + private \OCP\TextToImage\IProvider $provider; |
|
| 310 | + private IAppData $appData; |
|
| 311 | + |
|
| 312 | + public function __construct(\OCP\TextToImage\IProvider $provider, IAppData $appData) { |
|
| 313 | + $this->provider = $provider; |
|
| 314 | + $this->appData = $appData; |
|
| 315 | + } |
|
| 316 | + |
|
| 317 | + public function getId(): string { |
|
| 318 | + return Manager::LEGACY_PREFIX_TEXTTOIMAGE . $this->provider->getId(); |
|
| 319 | + } |
|
| 320 | + |
|
| 321 | + public function getName(): string { |
|
| 322 | + return $this->provider->getName(); |
|
| 323 | + } |
|
| 324 | + |
|
| 325 | + public function getTaskTypeId(): string { |
|
| 326 | + return TextToImage::ID; |
|
| 327 | + } |
|
| 328 | + |
|
| 329 | + public function getExpectedRuntime(): int { |
|
| 330 | + return $this->provider->getExpectedRuntime(); |
|
| 331 | + } |
|
| 332 | + |
|
| 333 | + public function getOptionalInputShape(): array { |
|
| 334 | + return []; |
|
| 335 | + } |
|
| 336 | + |
|
| 337 | + public function getOptionalOutputShape(): array { |
|
| 338 | + return []; |
|
| 339 | + } |
|
| 340 | + |
|
| 341 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 342 | + try { |
|
| 343 | + $folder = $this->appData->getFolder('text2image'); |
|
| 344 | + } catch (\OCP\Files\NotFoundException) { |
|
| 345 | + $folder = $this->appData->newFolder('text2image'); |
|
| 346 | + } |
|
| 347 | + $resources = []; |
|
| 348 | + $files = []; |
|
| 349 | + for ($i = 0; $i < $input['numberOfImages']; $i++) { |
|
| 350 | + $file = $folder->newFile(time() . '-' . rand(1, 100000) . '-' . $i); |
|
| 351 | + $files[] = $file; |
|
| 352 | + $resource = $file->write(); |
|
| 353 | + if ($resource !== false && $resource !== true && is_resource($resource)) { |
|
| 354 | + $resources[] = $resource; |
|
| 355 | + } else { |
|
| 356 | + throw new ProcessingException('Text2Image generation using provider "' . $this->getName() . '" failed: Couldn\'t open file to write.'); |
|
| 357 | + } |
|
| 358 | + } |
|
| 359 | + if ($this->provider instanceof \OCP\TextToImage\IProviderWithUserId) { |
|
| 360 | + $this->provider->setUserId($userId); |
|
| 361 | + } |
|
| 362 | + try { |
|
| 363 | + $this->provider->generate($input['input'], $resources); |
|
| 364 | + } catch (\RuntimeException $e) { |
|
| 365 | + throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 366 | + } |
|
| 367 | + for ($i = 0; $i < $input['numberOfImages']; $i++) { |
|
| 368 | + if (is_resource($resources[$i])) { |
|
| 369 | + // If $resource hasn't been closed yet, we'll do that here |
|
| 370 | + fclose($resources[$i]); |
|
| 371 | + } |
|
| 372 | + } |
|
| 373 | + return ['images' => array_map(fn (ISimpleFile $file) => $file->getContent(), $files)]; |
|
| 374 | + } |
|
| 375 | + |
|
| 376 | + public function getInputShapeEnumValues(): array { |
|
| 377 | + return []; |
|
| 378 | + } |
|
| 379 | + |
|
| 380 | + public function getInputShapeDefaults(): array { |
|
| 381 | + return []; |
|
| 382 | + } |
|
| 383 | + |
|
| 384 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 385 | + return []; |
|
| 386 | + } |
|
| 387 | + |
|
| 388 | + public function getOptionalInputShapeDefaults(): array { |
|
| 389 | + return []; |
|
| 390 | + } |
|
| 391 | + |
|
| 392 | + public function getOutputShapeEnumValues(): array { |
|
| 393 | + return []; |
|
| 394 | + } |
|
| 395 | + |
|
| 396 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 397 | + return []; |
|
| 398 | + } |
|
| 399 | + }; |
|
| 400 | + $newProviders[$newProvider->getId()] = $newProvider; |
|
| 401 | + } |
|
| 402 | + |
|
| 403 | + return $newProviders; |
|
| 404 | + } |
|
| 405 | + |
|
| 406 | + /** |
|
| 407 | + * This is almost a copy of SpeechToTextManager->getProviders |
|
| 408 | + * to avoid a dependency cycle between SpeechToTextManager and TaskProcessingManager |
|
| 409 | + */ |
|
| 410 | + private function _getRawSpeechToTextProviders(): array { |
|
| 411 | + $context = $this->coordinator->getRegistrationContext(); |
|
| 412 | + if ($context === null) { |
|
| 413 | + return []; |
|
| 414 | + } |
|
| 415 | + $providers = []; |
|
| 416 | + foreach ($context->getSpeechToTextProviders() as $providerServiceRegistration) { |
|
| 417 | + $class = $providerServiceRegistration->getService(); |
|
| 418 | + try { |
|
| 419 | + $providers[$class] = $this->serverContainer->get($class); |
|
| 420 | + } catch (NotFoundExceptionInterface|ContainerExceptionInterface|\Throwable $e) { |
|
| 421 | + $this->logger->error('Failed to load SpeechToText provider ' . $class, [ |
|
| 422 | + 'exception' => $e, |
|
| 423 | + ]); |
|
| 424 | + } |
|
| 425 | + } |
|
| 426 | + |
|
| 427 | + return $providers; |
|
| 428 | + } |
|
| 429 | + |
|
| 430 | + /** |
|
| 431 | + * @return IProvider[] |
|
| 432 | + */ |
|
| 433 | + private function _getSpeechToTextProviders(): array { |
|
| 434 | + $oldProviders = $this->_getRawSpeechToTextProviders(); |
|
| 435 | + $newProviders = []; |
|
| 436 | + foreach ($oldProviders as $oldProvider) { |
|
| 437 | + $newProvider = new class($oldProvider, $this->rootFolder, $this->appData) implements IProvider, ISynchronousProvider { |
|
| 438 | + private ISpeechToTextProvider $provider; |
|
| 439 | + private IAppData $appData; |
|
| 440 | + |
|
| 441 | + private IRootFolder $rootFolder; |
|
| 442 | + |
|
| 443 | + public function __construct(ISpeechToTextProvider $provider, IRootFolder $rootFolder, IAppData $appData) { |
|
| 444 | + $this->provider = $provider; |
|
| 445 | + $this->rootFolder = $rootFolder; |
|
| 446 | + $this->appData = $appData; |
|
| 447 | + } |
|
| 448 | + |
|
| 449 | + public function getId(): string { |
|
| 450 | + if ($this->provider instanceof ISpeechToTextProviderWithId) { |
|
| 451 | + return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider->getId(); |
|
| 452 | + } |
|
| 453 | + return Manager::LEGACY_PREFIX_SPEECHTOTEXT . $this->provider::class; |
|
| 454 | + } |
|
| 455 | + |
|
| 456 | + public function getName(): string { |
|
| 457 | + return $this->provider->getName(); |
|
| 458 | + } |
|
| 459 | + |
|
| 460 | + public function getTaskTypeId(): string { |
|
| 461 | + return AudioToText::ID; |
|
| 462 | + } |
|
| 463 | + |
|
| 464 | + public function getExpectedRuntime(): int { |
|
| 465 | + return 60; |
|
| 466 | + } |
|
| 467 | + |
|
| 468 | + public function getOptionalInputShape(): array { |
|
| 469 | + return []; |
|
| 470 | + } |
|
| 471 | + |
|
| 472 | + public function getOptionalOutputShape(): array { |
|
| 473 | + return []; |
|
| 474 | + } |
|
| 475 | + |
|
| 476 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 477 | + if ($this->provider instanceof \OCP\SpeechToText\ISpeechToTextProviderWithUserId) { |
|
| 478 | + $this->provider->setUserId($userId); |
|
| 479 | + } |
|
| 480 | + try { |
|
| 481 | + $result = $this->provider->transcribeFile($input['input']); |
|
| 482 | + } catch (\RuntimeException $e) { |
|
| 483 | + throw new ProcessingException($e->getMessage(), 0, $e); |
|
| 484 | + } |
|
| 485 | + return ['output' => $result]; |
|
| 486 | + } |
|
| 487 | + |
|
| 488 | + public function getInputShapeEnumValues(): array { |
|
| 489 | + return []; |
|
| 490 | + } |
|
| 491 | + |
|
| 492 | + public function getInputShapeDefaults(): array { |
|
| 493 | + return []; |
|
| 494 | + } |
|
| 495 | + |
|
| 496 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 497 | + return []; |
|
| 498 | + } |
|
| 499 | + |
|
| 500 | + public function getOptionalInputShapeDefaults(): array { |
|
| 501 | + return []; |
|
| 502 | + } |
|
| 503 | + |
|
| 504 | + public function getOutputShapeEnumValues(): array { |
|
| 505 | + return []; |
|
| 506 | + } |
|
| 507 | + |
|
| 508 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 509 | + return []; |
|
| 510 | + } |
|
| 511 | + }; |
|
| 512 | + $newProviders[$newProvider->getId()] = $newProvider; |
|
| 513 | + } |
|
| 514 | + |
|
| 515 | + return $newProviders; |
|
| 516 | + } |
|
| 517 | + |
|
| 518 | + /** |
|
| 519 | + * Dispatches the event to collect external providers and task types. |
|
| 520 | + * Caches the result within the request. |
|
| 521 | + */ |
|
| 522 | + private function dispatchGetProvidersEvent(): GetTaskProcessingProvidersEvent { |
|
| 523 | + if ($this->eventResult !== null) { |
|
| 524 | + return $this->eventResult; |
|
| 525 | + } |
|
| 526 | + |
|
| 527 | + $this->eventResult = new GetTaskProcessingProvidersEvent(); |
|
| 528 | + $this->dispatcher->dispatchTyped($this->eventResult); |
|
| 529 | + return $this->eventResult ; |
|
| 530 | + } |
|
| 531 | + |
|
| 532 | + /** |
|
| 533 | + * @return IProvider[] |
|
| 534 | + */ |
|
| 535 | + private function _getProviders(): array { |
|
| 536 | + $context = $this->coordinator->getRegistrationContext(); |
|
| 537 | + |
|
| 538 | + if ($context === null) { |
|
| 539 | + return []; |
|
| 540 | + } |
|
| 541 | + |
|
| 542 | + $providers = []; |
|
| 543 | + |
|
| 544 | + foreach ($context->getTaskProcessingProviders() as $providerServiceRegistration) { |
|
| 545 | + $class = $providerServiceRegistration->getService(); |
|
| 546 | + try { |
|
| 547 | + /** @var IProvider $provider */ |
|
| 548 | + $provider = $this->serverContainer->get($class); |
|
| 549 | + if (isset($providers[$provider->getId()])) { |
|
| 550 | + $this->logger->warning('Task processing provider ' . $class . ' is using ID ' . $provider->getId() . ' which is already used by ' . $providers[$provider->getId()]::class); |
|
| 551 | + } |
|
| 552 | + $providers[$provider->getId()] = $provider; |
|
| 553 | + } catch (\Throwable $e) { |
|
| 554 | + $this->logger->error('Failed to load task processing provider ' . $class, [ |
|
| 555 | + 'exception' => $e, |
|
| 556 | + ]); |
|
| 557 | + } |
|
| 558 | + } |
|
| 559 | + |
|
| 560 | + $event = $this->dispatchGetProvidersEvent(); |
|
| 561 | + $externalProviders = $event->getProviders(); |
|
| 562 | + foreach ($externalProviders as $provider) { |
|
| 563 | + if (!isset($providers[$provider->getId()])) { |
|
| 564 | + $providers[$provider->getId()] = $provider; |
|
| 565 | + } else { |
|
| 566 | + $this->logger->info('Skipping external task processing provider with ID ' . $provider->getId() . ' because a local provider with the same ID already exists.'); |
|
| 567 | + } |
|
| 568 | + } |
|
| 569 | + |
|
| 570 | + $providers += $this->_getTextProcessingProviders() + $this->_getTextToImageProviders() + $this->_getSpeechToTextProviders(); |
|
| 571 | + |
|
| 572 | + return $providers; |
|
| 573 | + } |
|
| 574 | + |
|
| 575 | + /** |
|
| 576 | + * @return ITaskType[] |
|
| 577 | + */ |
|
| 578 | + private function _getTaskTypes(): array { |
|
| 579 | + $context = $this->coordinator->getRegistrationContext(); |
|
| 580 | + |
|
| 581 | + if ($context === null) { |
|
| 582 | + return []; |
|
| 583 | + } |
|
| 584 | + |
|
| 585 | + if ($this->taskTypes !== null) { |
|
| 586 | + return $this->taskTypes; |
|
| 587 | + } |
|
| 588 | + |
|
| 589 | + // Default task types |
|
| 590 | + $taskTypes = [ |
|
| 591 | + \OCP\TaskProcessing\TaskTypes\TextToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToText::class), |
|
| 592 | + \OCP\TaskProcessing\TaskTypes\TextToTextTopics::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextTopics::class), |
|
| 593 | + \OCP\TaskProcessing\TaskTypes\TextToTextHeadline::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextHeadline::class), |
|
| 594 | + \OCP\TaskProcessing\TaskTypes\TextToTextSummary::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextSummary::class), |
|
| 595 | + \OCP\TaskProcessing\TaskTypes\TextToTextFormalization::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextFormalization::class), |
|
| 596 | + \OCP\TaskProcessing\TaskTypes\TextToTextSimplification::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextSimplification::class), |
|
| 597 | + \OCP\TaskProcessing\TaskTypes\TextToTextChat::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChat::class), |
|
| 598 | + \OCP\TaskProcessing\TaskTypes\TextToTextTranslate::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextTranslate::class), |
|
| 599 | + \OCP\TaskProcessing\TaskTypes\TextToTextReformulation::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextReformulation::class), |
|
| 600 | + \OCP\TaskProcessing\TaskTypes\TextToImage::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToImage::class), |
|
| 601 | + \OCP\TaskProcessing\TaskTypes\AudioToText::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToText::class), |
|
| 602 | + \OCP\TaskProcessing\TaskTypes\ContextWrite::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextWrite::class), |
|
| 603 | + \OCP\TaskProcessing\TaskTypes\GenerateEmoji::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\GenerateEmoji::class), |
|
| 604 | + \OCP\TaskProcessing\TaskTypes\TextToTextChangeTone::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChangeTone::class), |
|
| 605 | + \OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextChatWithTools::class), |
|
| 606 | + \OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentInteraction::class), |
|
| 607 | + \OCP\TaskProcessing\TaskTypes\TextToTextProofread::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToTextProofread::class), |
|
| 608 | + \OCP\TaskProcessing\TaskTypes\TextToSpeech::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\TextToSpeech::class), |
|
| 609 | + \OCP\TaskProcessing\TaskTypes\AudioToAudioChat::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AudioToAudioChat::class), |
|
| 610 | + \OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\ContextAgentAudioInteraction::class), |
|
| 611 | + \OCP\TaskProcessing\TaskTypes\AnalyzeImages::ID => \OCP\Server::get(\OCP\TaskProcessing\TaskTypes\AnalyzeImages::class), |
|
| 612 | + ]; |
|
| 613 | + |
|
| 614 | + foreach ($context->getTaskProcessingTaskTypes() as $providerServiceRegistration) { |
|
| 615 | + $class = $providerServiceRegistration->getService(); |
|
| 616 | + try { |
|
| 617 | + /** @var ITaskType $provider */ |
|
| 618 | + $taskType = $this->serverContainer->get($class); |
|
| 619 | + if (isset($taskTypes[$taskType->getId()])) { |
|
| 620 | + $this->logger->warning('Task processing task type ' . $class . ' is using ID ' . $taskType->getId() . ' which is already used by ' . $taskTypes[$taskType->getId()]::class); |
|
| 621 | + } |
|
| 622 | + $taskTypes[$taskType->getId()] = $taskType; |
|
| 623 | + } catch (\Throwable $e) { |
|
| 624 | + $this->logger->error('Failed to load task processing task type ' . $class, [ |
|
| 625 | + 'exception' => $e, |
|
| 626 | + ]); |
|
| 627 | + } |
|
| 628 | + } |
|
| 629 | + |
|
| 630 | + $event = $this->dispatchGetProvidersEvent(); |
|
| 631 | + $externalTaskTypes = $event->getTaskTypes(); |
|
| 632 | + foreach ($externalTaskTypes as $taskType) { |
|
| 633 | + if (isset($taskTypes[$taskType->getId()])) { |
|
| 634 | + $this->logger->warning('External task processing task type is using ID ' . $taskType->getId() . ' which is already used by a locally registered task type (' . get_class($taskTypes[$taskType->getId()]) . ')'); |
|
| 635 | + } |
|
| 636 | + $taskTypes[$taskType->getId()] = $taskType; |
|
| 637 | + } |
|
| 638 | + |
|
| 639 | + $taskTypes += $this->_getTextProcessingTaskTypes(); |
|
| 640 | + |
|
| 641 | + $this->taskTypes = $taskTypes; |
|
| 642 | + return $this->taskTypes; |
|
| 643 | + } |
|
| 644 | + |
|
| 645 | + /** |
|
| 646 | + * @return array |
|
| 647 | + */ |
|
| 648 | + private function _getTaskTypeSettings(): array { |
|
| 649 | + try { |
|
| 650 | + $json = $this->appConfig->getValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true); |
|
| 651 | + if ($json === '') { |
|
| 652 | + return []; |
|
| 653 | + } |
|
| 654 | + return json_decode($json, true, flags: JSON_THROW_ON_ERROR); |
|
| 655 | + } catch (\JsonException $e) { |
|
| 656 | + $this->logger->error('Failed to get settings. JSON Error in ai.taskprocessing_type_preferences', ['exception' => $e]); |
|
| 657 | + $taskTypeSettings = []; |
|
| 658 | + $taskTypes = $this->_getTaskTypes(); |
|
| 659 | + foreach ($taskTypes as $taskType) { |
|
| 660 | + $taskTypeSettings[$taskType->getId()] = false; |
|
| 661 | + }; |
|
| 662 | + |
|
| 663 | + return $taskTypeSettings; |
|
| 664 | + } |
|
| 665 | + |
|
| 666 | + } |
|
| 667 | + |
|
| 668 | + /** |
|
| 669 | + * @param ShapeDescriptor[] $spec |
|
| 670 | + * @param array<array-key, string|numeric> $defaults |
|
| 671 | + * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 672 | + * @param array $io |
|
| 673 | + * @param bool $optional |
|
| 674 | + * @return void |
|
| 675 | + * @throws ValidationException |
|
| 676 | + */ |
|
| 677 | + private static function validateInput(array $spec, array $defaults, array $enumValues, array $io, bool $optional = false): void { |
|
| 678 | + foreach ($spec as $key => $descriptor) { |
|
| 679 | + $type = $descriptor->getShapeType(); |
|
| 680 | + if (!isset($io[$key])) { |
|
| 681 | + if ($optional) { |
|
| 682 | + continue; |
|
| 683 | + } |
|
| 684 | + if (isset($defaults[$key])) { |
|
| 685 | + if (EShapeType::getScalarType($type) !== $type) { |
|
| 686 | + throw new ValidationException('Provider tried to set a default value for a non-scalar slot'); |
|
| 687 | + } |
|
| 688 | + if (EShapeType::isFileType($type)) { |
|
| 689 | + throw new ValidationException('Provider tried to set a default value for a slot that is not text or number'); |
|
| 690 | + } |
|
| 691 | + $type->validateInput($defaults[$key]); |
|
| 692 | + continue; |
|
| 693 | + } |
|
| 694 | + throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 695 | + } |
|
| 696 | + try { |
|
| 697 | + $type->validateInput($io[$key]); |
|
| 698 | + if ($type === EShapeType::Enum) { |
|
| 699 | + if (!isset($enumValues[$key])) { |
|
| 700 | + throw new ValidationException('Provider did not provide enum values for an enum slot: "' . $key . '"'); |
|
| 701 | + } |
|
| 702 | + $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 703 | + } |
|
| 704 | + } catch (ValidationException $e) { |
|
| 705 | + throw new ValidationException('Failed to validate input key "' . $key . '": ' . $e->getMessage()); |
|
| 706 | + } |
|
| 707 | + } |
|
| 708 | + } |
|
| 709 | + |
|
| 710 | + /** |
|
| 711 | + * Takes task input data and replaces fileIds with File objects |
|
| 712 | + * |
|
| 713 | + * @param array<array-key, list<numeric|string>|numeric|string> $input |
|
| 714 | + * @param array<array-key, numeric|string> ...$defaultSpecs the specs |
|
| 715 | + * @return array<array-key, list<numeric|string>|numeric|string> |
|
| 716 | + */ |
|
| 717 | + public function fillInputDefaults(array $input, ...$defaultSpecs): array { |
|
| 718 | + $spec = array_reduce($defaultSpecs, fn ($carry, $spec) => array_merge($carry, $spec), []); |
|
| 719 | + return array_merge($spec, $input); |
|
| 720 | + } |
|
| 721 | + |
|
| 722 | + /** |
|
| 723 | + * @param ShapeDescriptor[] $spec |
|
| 724 | + * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 725 | + * @param array $io |
|
| 726 | + * @param bool $optional |
|
| 727 | + * @return void |
|
| 728 | + * @throws ValidationException |
|
| 729 | + */ |
|
| 730 | + private static function validateOutputWithFileIds(array $spec, array $enumValues, array $io, bool $optional = false): void { |
|
| 731 | + foreach ($spec as $key => $descriptor) { |
|
| 732 | + $type = $descriptor->getShapeType(); |
|
| 733 | + if (!isset($io[$key])) { |
|
| 734 | + if ($optional) { |
|
| 735 | + continue; |
|
| 736 | + } |
|
| 737 | + throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 738 | + } |
|
| 739 | + try { |
|
| 740 | + $type->validateOutputWithFileIds($io[$key]); |
|
| 741 | + if (isset($enumValues[$key])) { |
|
| 742 | + $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 743 | + } |
|
| 744 | + } catch (ValidationException $e) { |
|
| 745 | + throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage()); |
|
| 746 | + } |
|
| 747 | + } |
|
| 748 | + } |
|
| 749 | + |
|
| 750 | + /** |
|
| 751 | + * @param ShapeDescriptor[] $spec |
|
| 752 | + * @param array<array-key, ShapeEnumValue[]> $enumValues |
|
| 753 | + * @param array $io |
|
| 754 | + * @param bool $optional |
|
| 755 | + * @return void |
|
| 756 | + * @throws ValidationException |
|
| 757 | + */ |
|
| 758 | + private static function validateOutputWithFileData(array $spec, array $enumValues, array $io, bool $optional = false): void { |
|
| 759 | + foreach ($spec as $key => $descriptor) { |
|
| 760 | + $type = $descriptor->getShapeType(); |
|
| 761 | + if (!isset($io[$key])) { |
|
| 762 | + if ($optional) { |
|
| 763 | + continue; |
|
| 764 | + } |
|
| 765 | + throw new ValidationException('Missing key: "' . $key . '"'); |
|
| 766 | + } |
|
| 767 | + try { |
|
| 768 | + $type->validateOutputWithFileData($io[$key]); |
|
| 769 | + if (isset($enumValues[$key])) { |
|
| 770 | + $type->validateEnum($io[$key], $enumValues[$key]); |
|
| 771 | + } |
|
| 772 | + } catch (ValidationException $e) { |
|
| 773 | + throw new ValidationException('Failed to validate output key "' . $key . '": ' . $e->getMessage()); |
|
| 774 | + } |
|
| 775 | + } |
|
| 776 | + } |
|
| 777 | + |
|
| 778 | + /** |
|
| 779 | + * @param array<array-key, T> $array The array to filter |
|
| 780 | + * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 781 | + * @return array<array-key, T> |
|
| 782 | + * @psalm-template T |
|
| 783 | + */ |
|
| 784 | + private function removeSuperfluousArrayKeys(array $array, ...$specs): array { |
|
| 785 | + $keys = array_unique(array_reduce($specs, fn ($carry, $spec) => array_merge($carry, array_keys($spec)), [])); |
|
| 786 | + $keys = array_filter($keys, fn ($key) => array_key_exists($key, $array)); |
|
| 787 | + $values = array_map(fn (string $key) => $array[$key], $keys); |
|
| 788 | + return array_combine($keys, $values); |
|
| 789 | + } |
|
| 790 | + |
|
| 791 | + public function hasProviders(): bool { |
|
| 792 | + return count($this->getProviders()) !== 0; |
|
| 793 | + } |
|
| 794 | + |
|
| 795 | + public function getProviders(): array { |
|
| 796 | + if ($this->providers === null) { |
|
| 797 | + $this->providers = $this->_getProviders(); |
|
| 798 | + } |
|
| 799 | + |
|
| 800 | + return $this->providers; |
|
| 801 | + } |
|
| 802 | + |
|
| 803 | + public function getPreferredProvider(string $taskTypeId) { |
|
| 804 | + try { |
|
| 805 | + if ($this->preferences === null) { |
|
| 806 | + $this->preferences = $this->distributedCache->get('ai.taskprocessing_provider_preferences'); |
|
| 807 | + if ($this->preferences === null) { |
|
| 808 | + $this->preferences = json_decode( |
|
| 809 | + $this->appConfig->getValueString('core', 'ai.taskprocessing_provider_preferences', 'null', lazy: true), |
|
| 810 | + associative: true, |
|
| 811 | + flags: JSON_THROW_ON_ERROR, |
|
| 812 | + ); |
|
| 813 | + $this->distributedCache->set('ai.taskprocessing_provider_preferences', $this->preferences, 60 * 3); |
|
| 814 | + } |
|
| 815 | + } |
|
| 816 | + |
|
| 817 | + $providers = $this->getProviders(); |
|
| 818 | + if (isset($this->preferences[$taskTypeId])) { |
|
| 819 | + $providersById = $this->providersById ?? array_reduce($providers, static function (array $carry, IProvider $provider) { |
|
| 820 | + $carry[$provider->getId()] = $provider; |
|
| 821 | + return $carry; |
|
| 822 | + }, []); |
|
| 823 | + $this->providersById = $providersById; |
|
| 824 | + if (isset($providersById[$this->preferences[$taskTypeId]])) { |
|
| 825 | + return $providersById[$this->preferences[$taskTypeId]]; |
|
| 826 | + } |
|
| 827 | + } |
|
| 828 | + // By default, use the first available provider |
|
| 829 | + foreach ($providers as $provider) { |
|
| 830 | + if ($provider->getTaskTypeId() === $taskTypeId) { |
|
| 831 | + return $provider; |
|
| 832 | + } |
|
| 833 | + } |
|
| 834 | + } catch (\JsonException $e) { |
|
| 835 | + $this->logger->warning('Failed to parse provider preferences while getting preferred provider for task type ' . $taskTypeId, ['exception' => $e]); |
|
| 836 | + } |
|
| 837 | + throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found'); |
|
| 838 | + } |
|
| 839 | + |
|
| 840 | + public function getAvailableTaskTypes(bool $showDisabled = false, ?string $userId = null): array { |
|
| 841 | + // We cache by language, because some task type fields are translated |
|
| 842 | + $cacheKey = self::TASK_TYPES_CACHE_KEY . ':' . $this->l10nFactory->findLanguage(); |
|
| 843 | + |
|
| 844 | + // userId will be obtained from the session if left to null |
|
| 845 | + if (!$this->checkGuestAccess($userId)) { |
|
| 846 | + return []; |
|
| 847 | + } |
|
| 848 | + if ($this->availableTaskTypes === null) { |
|
| 849 | + $cachedValue = $this->distributedCache->get($cacheKey); |
|
| 850 | + if ($cachedValue !== null) { |
|
| 851 | + $this->availableTaskTypes = unserialize($cachedValue); |
|
| 852 | + } |
|
| 853 | + } |
|
| 854 | + // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. |
|
| 855 | + if ($this->availableTaskTypes === null || $showDisabled) { |
|
| 856 | + $taskTypes = $this->_getTaskTypes(); |
|
| 857 | + $taskTypeSettings = $this->_getTaskTypeSettings(); |
|
| 858 | + |
|
| 859 | + $availableTaskTypes = []; |
|
| 860 | + foreach ($taskTypes as $taskType) { |
|
| 861 | + if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) { |
|
| 862 | + continue; |
|
| 863 | + } |
|
| 864 | + try { |
|
| 865 | + $provider = $this->getPreferredProvider($taskType->getId()); |
|
| 866 | + } catch (\OCP\TaskProcessing\Exception\Exception $e) { |
|
| 867 | + continue; |
|
| 868 | + } |
|
| 869 | + try { |
|
| 870 | + $availableTaskTypes[$provider->getTaskTypeId()] = [ |
|
| 871 | + 'name' => $taskType->getName(), |
|
| 872 | + 'description' => $taskType->getDescription(), |
|
| 873 | + 'optionalInputShape' => $provider->getOptionalInputShape(), |
|
| 874 | + 'inputShapeEnumValues' => $provider->getInputShapeEnumValues(), |
|
| 875 | + 'inputShapeDefaults' => $provider->getInputShapeDefaults(), |
|
| 876 | + 'inputShape' => $taskType->getInputShape(), |
|
| 877 | + 'optionalInputShapeEnumValues' => $provider->getOptionalInputShapeEnumValues(), |
|
| 878 | + 'optionalInputShapeDefaults' => $provider->getOptionalInputShapeDefaults(), |
|
| 879 | + 'outputShape' => $taskType->getOutputShape(), |
|
| 880 | + 'outputShapeEnumValues' => $provider->getOutputShapeEnumValues(), |
|
| 881 | + 'optionalOutputShape' => $provider->getOptionalOutputShape(), |
|
| 882 | + 'optionalOutputShapeEnumValues' => $provider->getOptionalOutputShapeEnumValues(), |
|
| 883 | + 'isInternal' => $taskType instanceof IInternalTaskType, |
|
| 884 | + ]; |
|
| 885 | + } catch (\Throwable $e) { |
|
| 886 | + $this->logger->error('Failed to set up TaskProcessing provider ' . $provider::class, ['exception' => $e]); |
|
| 887 | + } |
|
| 888 | + } |
|
| 889 | + |
|
| 890 | + if ($showDisabled) { |
|
| 891 | + // Do not cache showDisabled, ever. |
|
| 892 | + return $availableTaskTypes; |
|
| 893 | + } |
|
| 894 | + |
|
| 895 | + $this->availableTaskTypes = $availableTaskTypes; |
|
| 896 | + $this->distributedCache->set($cacheKey, serialize($this->availableTaskTypes), 60); |
|
| 897 | + } |
|
| 898 | + |
|
| 899 | + |
|
| 900 | + return $this->availableTaskTypes; |
|
| 901 | + } |
|
| 902 | + public function getAvailableTaskTypeIds(bool $showDisabled = false, ?string $userId = null): array { |
|
| 903 | + // userId will be obtained from the session if left to null |
|
| 904 | + if (!$this->checkGuestAccess($userId)) { |
|
| 905 | + return []; |
|
| 906 | + } |
|
| 907 | + if ($this->availableTaskTypeIds === null) { |
|
| 908 | + $cachedValue = $this->distributedCache->get(self::TASK_TYPE_IDS_CACHE_KEY); |
|
| 909 | + if ($cachedValue !== null) { |
|
| 910 | + $this->availableTaskTypeIds = $cachedValue; |
|
| 911 | + } |
|
| 912 | + } |
|
| 913 | + // Either we have no cache or showDisabled is turned on, which we don't want to cache, ever. |
|
| 914 | + if ($this->availableTaskTypeIds === null || $showDisabled) { |
|
| 915 | + $taskTypes = $this->_getTaskTypes(); |
|
| 916 | + $taskTypeSettings = $this->_getTaskTypeSettings(); |
|
| 917 | + |
|
| 918 | + $availableTaskTypeIds = []; |
|
| 919 | + foreach ($taskTypes as $taskType) { |
|
| 920 | + if ((!$showDisabled) && isset($taskTypeSettings[$taskType->getId()]) && !$taskTypeSettings[$taskType->getId()]) { |
|
| 921 | + continue; |
|
| 922 | + } |
|
| 923 | + try { |
|
| 924 | + $provider = $this->getPreferredProvider($taskType->getId()); |
|
| 925 | + } catch (\OCP\TaskProcessing\Exception\Exception $e) { |
|
| 926 | + continue; |
|
| 927 | + } |
|
| 928 | + $availableTaskTypeIds[] = $taskType->getId(); |
|
| 929 | + } |
|
| 930 | + |
|
| 931 | + if ($showDisabled) { |
|
| 932 | + // Do not cache showDisabled, ever. |
|
| 933 | + return $availableTaskTypeIds; |
|
| 934 | + } |
|
| 935 | + |
|
| 936 | + $this->availableTaskTypeIds = $availableTaskTypeIds; |
|
| 937 | + $this->distributedCache->set(self::TASK_TYPE_IDS_CACHE_KEY, $this->availableTaskTypeIds, 60); |
|
| 938 | + } |
|
| 939 | + |
|
| 940 | + |
|
| 941 | + return $this->availableTaskTypeIds; |
|
| 942 | + } |
|
| 943 | + |
|
| 944 | + public function canHandleTask(Task $task): bool { |
|
| 945 | + return isset($this->getAvailableTaskTypes()[$task->getTaskTypeId()]); |
|
| 946 | + } |
|
| 947 | + |
|
| 948 | + private function checkGuestAccess(?string $userId = null): bool { |
|
| 949 | + if ($userId === null && !$this->userSession->isLoggedIn()) { |
|
| 950 | + return true; |
|
| 951 | + } |
|
| 952 | + if ($userId === null) { |
|
| 953 | + $user = $this->userSession->getUser(); |
|
| 954 | + } else { |
|
| 955 | + $user = $this->userManager->get($userId); |
|
| 956 | + } |
|
| 957 | + |
|
| 958 | + $guestsAllowed = $this->appConfig->getValueString('core', 'ai.taskprocessing_guests', 'false'); |
|
| 959 | + if ($guestsAllowed == 'true' || !class_exists(\OCA\Guests\UserBackend::class) || !($user->getBackend() instanceof \OCA\Guests\UserBackend)) { |
|
| 960 | + return true; |
|
| 961 | + } |
|
| 962 | + return false; |
|
| 963 | + } |
|
| 964 | + |
|
| 965 | + public function scheduleTask(Task $task): void { |
|
| 966 | + if (!$this->checkGuestAccess($task->getUserId())) { |
|
| 967 | + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); |
|
| 968 | + } |
|
| 969 | + if (!$this->canHandleTask($task)) { |
|
| 970 | + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); |
|
| 971 | + } |
|
| 972 | + $this->prepareTask($task); |
|
| 973 | + $task->setStatus(Task::STATUS_SCHEDULED); |
|
| 974 | + $this->storeTask($task); |
|
| 975 | + // schedule synchronous job if the provider is synchronous |
|
| 976 | + $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 977 | + if ($provider instanceof ISynchronousProvider) { |
|
| 978 | + $this->jobList->add(SynchronousBackgroundJob::class, null); |
|
| 979 | + } |
|
| 980 | + if ($provider instanceof ITriggerableProvider) { |
|
| 981 | + try { |
|
| 982 | + if (!$this->taskMapper->hasRunningTasksForTaskType($task->getTaskTypeId())) { |
|
| 983 | + // If no tasks are currently running for this task type, nudge the provider to ask for tasks |
|
| 984 | + try { |
|
| 985 | + $provider->trigger(); |
|
| 986 | + } catch (\Throwable $e) { |
|
| 987 | + $this->logger->error('Failed to trigger the provider after scheduling a task.', [ |
|
| 988 | + 'exception' => $e, |
|
| 989 | + 'taskId' => $task->getId(), |
|
| 990 | + 'providerId' => $provider->getId(), |
|
| 991 | + ]); |
|
| 992 | + } |
|
| 993 | + } |
|
| 994 | + } catch (Exception $e) { |
|
| 995 | + $this->logger->error('Failed to check DB for running tasks after a task was scheduled for a triggerable provider. Not triggering the provider.', [ |
|
| 996 | + 'exception' => $e, |
|
| 997 | + 'taskId' => $task->getId(), |
|
| 998 | + 'providerId' => $provider->getId() |
|
| 999 | + ]); |
|
| 1000 | + } |
|
| 1001 | + } |
|
| 1002 | + } |
|
| 1003 | + |
|
| 1004 | + public function runTask(Task $task): Task { |
|
| 1005 | + if (!$this->checkGuestAccess($task->getUserId())) { |
|
| 1006 | + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('Access to this resource is forbidden for guests.'); |
|
| 1007 | + } |
|
| 1008 | + if (!$this->canHandleTask($task)) { |
|
| 1009 | + throw new \OCP\TaskProcessing\Exception\PreConditionNotMetException('No task processing provider is installed that can handle this task type: ' . $task->getTaskTypeId()); |
|
| 1010 | + } |
|
| 1011 | + |
|
| 1012 | + $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 1013 | + if ($provider instanceof ISynchronousProvider) { |
|
| 1014 | + $this->prepareTask($task); |
|
| 1015 | + $task->setStatus(Task::STATUS_SCHEDULED); |
|
| 1016 | + $this->storeTask($task); |
|
| 1017 | + $this->processTask($task, $provider); |
|
| 1018 | + $task = $this->getTask($task->getId()); |
|
| 1019 | + } else { |
|
| 1020 | + $this->scheduleTask($task); |
|
| 1021 | + // poll task |
|
| 1022 | + while ($task->getStatus() === Task::STATUS_SCHEDULED || $task->getStatus() === Task::STATUS_RUNNING) { |
|
| 1023 | + sleep(1); |
|
| 1024 | + $task = $this->getTask($task->getId()); |
|
| 1025 | + } |
|
| 1026 | + } |
|
| 1027 | + return $task; |
|
| 1028 | + } |
|
| 1029 | + |
|
| 1030 | + public function processTask(Task $task, ISynchronousProvider $provider): bool { |
|
| 1031 | + try { |
|
| 1032 | + try { |
|
| 1033 | + $input = $this->prepareInputData($task); |
|
| 1034 | + } catch (GenericFileException|NotPermittedException|LockedException|ValidationException|UnauthorizedException $e) { |
|
| 1035 | + $this->logger->warning('Failed to prepare input data for a TaskProcessing task with synchronous provider ' . $provider->getId(), ['exception' => $e]); |
|
| 1036 | + $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1037 | + return false; |
|
| 1038 | + } |
|
| 1039 | + try { |
|
| 1040 | + $this->setTaskStatus($task, Task::STATUS_RUNNING); |
|
| 1041 | + $output = $provider->process($task->getUserId(), $input, fn (float $progress) => $this->setTaskProgress($task->getId(), $progress)); |
|
| 1042 | + } catch (ProcessingException $e) { |
|
| 1043 | + $this->logger->warning('Failed to process a TaskProcessing task with synchronous provider ' . $provider->getId(), ['exception' => $e]); |
|
| 1044 | + $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1045 | + return false; |
|
| 1046 | + } catch (\Throwable $e) { |
|
| 1047 | + $this->logger->error('Unknown error while processing TaskProcessing task', ['exception' => $e]); |
|
| 1048 | + $this->setTaskResult($task->getId(), $e->getMessage(), null); |
|
| 1049 | + return false; |
|
| 1050 | + } |
|
| 1051 | + $this->setTaskResult($task->getId(), null, $output); |
|
| 1052 | + } catch (NotFoundException $e) { |
|
| 1053 | + $this->logger->info('Could not find task anymore after execution. Moving on.', ['exception' => $e]); |
|
| 1054 | + } catch (Exception $e) { |
|
| 1055 | + $this->logger->error('Failed to report result of TaskProcessing task', ['exception' => $e]); |
|
| 1056 | + } |
|
| 1057 | + return true; |
|
| 1058 | + } |
|
| 1059 | + |
|
| 1060 | + public function deleteTask(Task $task): void { |
|
| 1061 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1062 | + $this->taskMapper->delete($taskEntity); |
|
| 1063 | + } |
|
| 1064 | + |
|
| 1065 | + public function getTask(int $id): Task { |
|
| 1066 | + try { |
|
| 1067 | + $taskEntity = $this->taskMapper->find($id); |
|
| 1068 | + return $taskEntity->toPublicTask(); |
|
| 1069 | + } catch (DoesNotExistException $e) { |
|
| 1070 | + throw new NotFoundException('Couldn\'t find task with id ' . $id, 0, $e); |
|
| 1071 | + } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) { |
|
| 1072 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1073 | + } catch (\JsonException $e) { |
|
| 1074 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e); |
|
| 1075 | + } |
|
| 1076 | + } |
|
| 1077 | + |
|
| 1078 | + public function cancelTask(int $id): void { |
|
| 1079 | + $task = $this->getTask($id); |
|
| 1080 | + if ($task->getStatus() !== Task::STATUS_SCHEDULED && $task->getStatus() !== Task::STATUS_RUNNING) { |
|
| 1081 | + return; |
|
| 1082 | + } |
|
| 1083 | + $task->setStatus(Task::STATUS_CANCELLED); |
|
| 1084 | + $task->setEndedAt(time()); |
|
| 1085 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1086 | + try { |
|
| 1087 | + $this->taskMapper->update($taskEntity); |
|
| 1088 | + $this->runWebhook($task); |
|
| 1089 | + } catch (\OCP\DB\Exception $e) { |
|
| 1090 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1091 | + } |
|
| 1092 | + } |
|
| 1093 | + |
|
| 1094 | + public function setTaskProgress(int $id, float $progress): bool { |
|
| 1095 | + // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently |
|
| 1096 | + $task = $this->getTask($id); |
|
| 1097 | + if ($task->getStatus() === Task::STATUS_CANCELLED) { |
|
| 1098 | + return false; |
|
| 1099 | + } |
|
| 1100 | + // only set the start time if the task is going from scheduled to running |
|
| 1101 | + if ($task->getstatus() === Task::STATUS_SCHEDULED) { |
|
| 1102 | + $task->setStartedAt(time()); |
|
| 1103 | + } |
|
| 1104 | + $task->setStatus(Task::STATUS_RUNNING); |
|
| 1105 | + $task->setProgress($progress); |
|
| 1106 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1107 | + try { |
|
| 1108 | + $this->taskMapper->update($taskEntity); |
|
| 1109 | + } catch (\OCP\DB\Exception $e) { |
|
| 1110 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1111 | + } |
|
| 1112 | + return true; |
|
| 1113 | + } |
|
| 1114 | + |
|
| 1115 | + public function setTaskResult(int $id, ?string $error, ?array $result, bool $isUsingFileIds = false): void { |
|
| 1116 | + // TODO: Not sure if we should rather catch the exceptions of getTask here and fail silently |
|
| 1117 | + $task = $this->getTask($id); |
|
| 1118 | + if ($task->getStatus() === Task::STATUS_CANCELLED) { |
|
| 1119 | + $this->logger->info('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' finished but was cancelled in the mean time. Moving on without storing result.'); |
|
| 1120 | + return; |
|
| 1121 | + } |
|
| 1122 | + if ($error !== null) { |
|
| 1123 | + $task->setStatus(Task::STATUS_FAILED); |
|
| 1124 | + $task->setEndedAt(time()); |
|
| 1125 | + // truncate error message to 1000 characters |
|
| 1126 | + $task->setErrorMessage(mb_substr($error, 0, 1000)); |
|
| 1127 | + $this->logger->warning('A TaskProcessing ' . $task->getTaskTypeId() . ' task with id ' . $id . ' failed with the following message: ' . $error); |
|
| 1128 | + } elseif ($result !== null) { |
|
| 1129 | + $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1130 | + $outputShape = $taskTypes[$task->getTaskTypeId()]['outputShape']; |
|
| 1131 | + $outputShapeEnumValues = $taskTypes[$task->getTaskTypeId()]['outputShapeEnumValues']; |
|
| 1132 | + $optionalOutputShape = $taskTypes[$task->getTaskTypeId()]['optionalOutputShape']; |
|
| 1133 | + $optionalOutputShapeEnumValues = $taskTypes[$task->getTaskTypeId()]['optionalOutputShapeEnumValues']; |
|
| 1134 | + try { |
|
| 1135 | + // validate output |
|
| 1136 | + if (!$isUsingFileIds) { |
|
| 1137 | + $this->validateOutputWithFileData($outputShape, $outputShapeEnumValues, $result); |
|
| 1138 | + $this->validateOutputWithFileData($optionalOutputShape, $optionalOutputShapeEnumValues, $result, true); |
|
| 1139 | + } else { |
|
| 1140 | + $this->validateOutputWithFileIds($outputShape, $outputShapeEnumValues, $result); |
|
| 1141 | + $this->validateOutputWithFileIds($optionalOutputShape, $optionalOutputShapeEnumValues, $result, true); |
|
| 1142 | + } |
|
| 1143 | + $output = $this->removeSuperfluousArrayKeys($result, $outputShape, $optionalOutputShape); |
|
| 1144 | + // extract raw data and put it in files, replace it with file ids |
|
| 1145 | + if (!$isUsingFileIds) { |
|
| 1146 | + $output = $this->encapsulateOutputFileData($output, $outputShape, $optionalOutputShape); |
|
| 1147 | + } else { |
|
| 1148 | + $this->validateOutputFileIds($output, $outputShape, $optionalOutputShape); |
|
| 1149 | + } |
|
| 1150 | + // Turn file objects into IDs |
|
| 1151 | + foreach ($output as $key => $value) { |
|
| 1152 | + if ($value instanceof Node) { |
|
| 1153 | + $output[$key] = $value->getId(); |
|
| 1154 | + } |
|
| 1155 | + if (is_array($value) && isset($value[0]) && $value[0] instanceof Node) { |
|
| 1156 | + $output[$key] = array_map(fn ($node) => $node->getId(), $value); |
|
| 1157 | + } |
|
| 1158 | + } |
|
| 1159 | + $task->setOutput($output); |
|
| 1160 | + $task->setProgress(1); |
|
| 1161 | + $task->setStatus(Task::STATUS_SUCCESSFUL); |
|
| 1162 | + $task->setEndedAt(time()); |
|
| 1163 | + } catch (ValidationException $e) { |
|
| 1164 | + $task->setProgress(1); |
|
| 1165 | + $task->setStatus(Task::STATUS_FAILED); |
|
| 1166 | + $task->setEndedAt(time()); |
|
| 1167 | + $error = 'The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec'; |
|
| 1168 | + $task->setErrorMessage($error); |
|
| 1169 | + $this->logger->error($error, ['exception' => $e, 'output' => $result]); |
|
| 1170 | + } catch (NotPermittedException $e) { |
|
| 1171 | + $task->setProgress(1); |
|
| 1172 | + $task->setStatus(Task::STATUS_FAILED); |
|
| 1173 | + $task->setEndedAt(time()); |
|
| 1174 | + $error = 'The task was processed successfully but storing the output in a file failed'; |
|
| 1175 | + $task->setErrorMessage($error); |
|
| 1176 | + $this->logger->error($error, ['exception' => $e]); |
|
| 1177 | + } catch (InvalidPathException|\OCP\Files\NotFoundException $e) { |
|
| 1178 | + $task->setProgress(1); |
|
| 1179 | + $task->setStatus(Task::STATUS_FAILED); |
|
| 1180 | + $task->setEndedAt(time()); |
|
| 1181 | + $error = 'The task was processed successfully but the result file could not be found'; |
|
| 1182 | + $task->setErrorMessage($error); |
|
| 1183 | + $this->logger->error($error, ['exception' => $e]); |
|
| 1184 | + } |
|
| 1185 | + } |
|
| 1186 | + try { |
|
| 1187 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1188 | + } catch (\JsonException $e) { |
|
| 1189 | + throw new \OCP\TaskProcessing\Exception\Exception('The task was processed successfully but the provider\'s output could not be encoded as JSON for the database.', 0, $e); |
|
| 1190 | + } |
|
| 1191 | + try { |
|
| 1192 | + $this->taskMapper->update($taskEntity); |
|
| 1193 | + $this->runWebhook($task); |
|
| 1194 | + } catch (\OCP\DB\Exception $e) { |
|
| 1195 | + throw new \OCP\TaskProcessing\Exception\Exception($e->getMessage()); |
|
| 1196 | + } |
|
| 1197 | + if ($task->getStatus() === Task::STATUS_SUCCESSFUL) { |
|
| 1198 | + $event = new TaskSuccessfulEvent($task); |
|
| 1199 | + } else { |
|
| 1200 | + $event = new TaskFailedEvent($task, $error); |
|
| 1201 | + } |
|
| 1202 | + $this->dispatcher->dispatchTyped($event); |
|
| 1203 | + } |
|
| 1204 | + |
|
| 1205 | + public function getNextScheduledTask(array $taskTypeIds = [], array $taskIdsToIgnore = []): Task { |
|
| 1206 | + try { |
|
| 1207 | + $taskEntity = $this->taskMapper->findOldestScheduledByType($taskTypeIds, $taskIdsToIgnore); |
|
| 1208 | + return $taskEntity->toPublicTask(); |
|
| 1209 | + } catch (DoesNotExistException $e) { |
|
| 1210 | + throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', previous: $e); |
|
| 1211 | + } catch (\OCP\DB\Exception $e) { |
|
| 1212 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', previous: $e); |
|
| 1213 | + } catch (\JsonException $e) { |
|
| 1214 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', previous: $e); |
|
| 1215 | + } |
|
| 1216 | + } |
|
| 1217 | + |
|
| 1218 | + public function getNextScheduledTasks(array $taskTypeIds = [], array $taskIdsToIgnore = [], int $numberOfTasks = 1): array { |
|
| 1219 | + try { |
|
| 1220 | + return array_map(fn ($taskEntity) => $taskEntity->toPublicTask(), $this->taskMapper->findNOldestScheduledByType($taskTypeIds, $taskIdsToIgnore, $numberOfTasks)); |
|
| 1221 | + } catch (DoesNotExistException $e) { |
|
| 1222 | + throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', previous: $e); |
|
| 1223 | + } catch (\OCP\DB\Exception $e) { |
|
| 1224 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', previous: $e); |
|
| 1225 | + } catch (\JsonException $e) { |
|
| 1226 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', previous: $e); |
|
| 1227 | + } |
|
| 1228 | + } |
|
| 1229 | + |
|
| 1230 | + /** |
|
| 1231 | + * Takes task input data and replaces fileIds with File objects |
|
| 1232 | + * |
|
| 1233 | + * @param string|null $userId |
|
| 1234 | + * @param array<array-key, list<numeric|string>|numeric|string> $input |
|
| 1235 | + * @param ShapeDescriptor[] ...$specs the specs |
|
| 1236 | + * @return array<array-key, list<File|numeric|string>|numeric|string|File> |
|
| 1237 | + * @throws GenericFileException|LockedException|NotPermittedException|ValidationException|UnauthorizedException |
|
| 1238 | + */ |
|
| 1239 | + public function fillInputFileData(?string $userId, array $input, ...$specs): array { |
|
| 1240 | + if ($userId !== null) { |
|
| 1241 | + \OC_Util::setupFS($userId); |
|
| 1242 | + } |
|
| 1243 | + $newInputOutput = []; |
|
| 1244 | + $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1245 | + foreach ($spec as $key => $descriptor) { |
|
| 1246 | + $type = $descriptor->getShapeType(); |
|
| 1247 | + if (!isset($input[$key])) { |
|
| 1248 | + continue; |
|
| 1249 | + } |
|
| 1250 | + if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1251 | + $newInputOutput[$key] = $input[$key]; |
|
| 1252 | + continue; |
|
| 1253 | + } |
|
| 1254 | + if (EShapeType::getScalarType($type) === $type) { |
|
| 1255 | + // is scalar |
|
| 1256 | + $node = $this->validateFileId((int)$input[$key]); |
|
| 1257 | + $this->validateUserAccessToFile($input[$key], $userId); |
|
| 1258 | + $newInputOutput[$key] = $node; |
|
| 1259 | + } else { |
|
| 1260 | + // is list |
|
| 1261 | + $newInputOutput[$key] = []; |
|
| 1262 | + foreach ($input[$key] as $item) { |
|
| 1263 | + $node = $this->validateFileId((int)$item); |
|
| 1264 | + $this->validateUserAccessToFile($item, $userId); |
|
| 1265 | + $newInputOutput[$key][] = $node; |
|
| 1266 | + } |
|
| 1267 | + } |
|
| 1268 | + } |
|
| 1269 | + return $newInputOutput; |
|
| 1270 | + } |
|
| 1271 | + |
|
| 1272 | + public function getUserTask(int $id, ?string $userId): Task { |
|
| 1273 | + try { |
|
| 1274 | + $taskEntity = $this->taskMapper->findByIdAndUser($id, $userId); |
|
| 1275 | + return $taskEntity->toPublicTask(); |
|
| 1276 | + } catch (DoesNotExistException $e) { |
|
| 1277 | + throw new \OCP\TaskProcessing\Exception\NotFoundException('Could not find the task', 0, $e); |
|
| 1278 | + } catch (MultipleObjectsReturnedException|\OCP\DB\Exception $e) { |
|
| 1279 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the task', 0, $e); |
|
| 1280 | + } catch (\JsonException $e) { |
|
| 1281 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the task', 0, $e); |
|
| 1282 | + } |
|
| 1283 | + } |
|
| 1284 | + |
|
| 1285 | + public function getUserTasks(?string $userId, ?string $taskTypeId = null, ?string $customId = null): array { |
|
| 1286 | + try { |
|
| 1287 | + $taskEntities = $this->taskMapper->findByUserAndTaskType($userId, $taskTypeId, $customId); |
|
| 1288 | + return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1289 | + } catch (\OCP\DB\Exception $e) { |
|
| 1290 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $e); |
|
| 1291 | + } catch (\JsonException $e) { |
|
| 1292 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $e); |
|
| 1293 | + } |
|
| 1294 | + } |
|
| 1295 | + |
|
| 1296 | + public function getTasks( |
|
| 1297 | + ?string $userId, ?string $taskTypeId = null, ?string $appId = null, ?string $customId = null, |
|
| 1298 | + ?int $status = null, ?int $scheduleAfter = null, ?int $endedBefore = null, |
|
| 1299 | + ): array { |
|
| 1300 | + try { |
|
| 1301 | + $taskEntities = $this->taskMapper->findTasks($userId, $taskTypeId, $appId, $customId, $status, $scheduleAfter, $endedBefore); |
|
| 1302 | + return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1303 | + } catch (\OCP\DB\Exception $e) { |
|
| 1304 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding the tasks', 0, $e); |
|
| 1305 | + } catch (\JsonException $e) { |
|
| 1306 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding the tasks', 0, $e); |
|
| 1307 | + } |
|
| 1308 | + } |
|
| 1309 | + |
|
| 1310 | + public function getUserTasksByApp(?string $userId, string $appId, ?string $customId = null): array { |
|
| 1311 | + try { |
|
| 1312 | + $taskEntities = $this->taskMapper->findUserTasksByApp($userId, $appId, $customId); |
|
| 1313 | + return array_map(fn ($taskEntity): Task => $taskEntity->toPublicTask(), $taskEntities); |
|
| 1314 | + } catch (\OCP\DB\Exception $e) { |
|
| 1315 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem finding a task', 0, $e); |
|
| 1316 | + } catch (\JsonException $e) { |
|
| 1317 | + throw new \OCP\TaskProcessing\Exception\Exception('There was a problem parsing JSON after finding a task', 0, $e); |
|
| 1318 | + } |
|
| 1319 | + } |
|
| 1320 | + |
|
| 1321 | + /** |
|
| 1322 | + *Takes task input or output and replaces base64 data with file ids |
|
| 1323 | + * |
|
| 1324 | + * @param array $output |
|
| 1325 | + * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 1326 | + * @return array |
|
| 1327 | + * @throws NotPermittedException |
|
| 1328 | + */ |
|
| 1329 | + public function encapsulateOutputFileData(array $output, ...$specs): array { |
|
| 1330 | + $newOutput = []; |
|
| 1331 | + try { |
|
| 1332 | + $folder = $this->appData->getFolder('TaskProcessing'); |
|
| 1333 | + } catch (\OCP\Files\NotFoundException) { |
|
| 1334 | + $folder = $this->appData->newFolder('TaskProcessing'); |
|
| 1335 | + } |
|
| 1336 | + $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1337 | + foreach ($spec as $key => $descriptor) { |
|
| 1338 | + $type = $descriptor->getShapeType(); |
|
| 1339 | + if (!isset($output[$key])) { |
|
| 1340 | + continue; |
|
| 1341 | + } |
|
| 1342 | + if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1343 | + $newOutput[$key] = $output[$key]; |
|
| 1344 | + continue; |
|
| 1345 | + } |
|
| 1346 | + if (EShapeType::getScalarType($type) === $type) { |
|
| 1347 | + /** @var SimpleFile $file */ |
|
| 1348 | + $file = $folder->newFile(time() . '-' . rand(1, 100000), $output[$key]); |
|
| 1349 | + $newOutput[$key] = $file->getId(); // polymorphic call to SimpleFile |
|
| 1350 | + } else { |
|
| 1351 | + $newOutput = []; |
|
| 1352 | + foreach ($output[$key] as $item) { |
|
| 1353 | + /** @var SimpleFile $file */ |
|
| 1354 | + $file = $folder->newFile(time() . '-' . rand(1, 100000), $item); |
|
| 1355 | + $newOutput[$key][] = $file->getId(); |
|
| 1356 | + } |
|
| 1357 | + } |
|
| 1358 | + } |
|
| 1359 | + return $newOutput; |
|
| 1360 | + } |
|
| 1361 | + |
|
| 1362 | + /** |
|
| 1363 | + * @param Task $task |
|
| 1364 | + * @return array<array-key, list<numeric|string|File>|numeric|string|File> |
|
| 1365 | + * @throws GenericFileException |
|
| 1366 | + * @throws LockedException |
|
| 1367 | + * @throws NotPermittedException |
|
| 1368 | + * @throws ValidationException|UnauthorizedException |
|
| 1369 | + */ |
|
| 1370 | + public function prepareInputData(Task $task): array { |
|
| 1371 | + $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1372 | + $inputShape = $taskTypes[$task->getTaskTypeId()]['inputShape']; |
|
| 1373 | + $optionalInputShape = $taskTypes[$task->getTaskTypeId()]['optionalInputShape']; |
|
| 1374 | + $input = $task->getInput(); |
|
| 1375 | + $input = $this->removeSuperfluousArrayKeys($input, $inputShape, $optionalInputShape); |
|
| 1376 | + $input = $this->fillInputFileData($task->getUserId(), $input, $inputShape, $optionalInputShape); |
|
| 1377 | + return $input; |
|
| 1378 | + } |
|
| 1379 | + |
|
| 1380 | + public function lockTask(Task $task): bool { |
|
| 1381 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1382 | + if ($this->taskMapper->lockTask($taskEntity) === 0) { |
|
| 1383 | + return false; |
|
| 1384 | + } |
|
| 1385 | + $task->setStatus(Task::STATUS_RUNNING); |
|
| 1386 | + return true; |
|
| 1387 | + } |
|
| 1388 | + |
|
| 1389 | + /** |
|
| 1390 | + * @throws \JsonException |
|
| 1391 | + * @throws Exception |
|
| 1392 | + */ |
|
| 1393 | + public function setTaskStatus(Task $task, int $status): void { |
|
| 1394 | + $currentTaskStatus = $task->getStatus(); |
|
| 1395 | + if ($currentTaskStatus === Task::STATUS_SCHEDULED && $status === Task::STATUS_RUNNING) { |
|
| 1396 | + $task->setStartedAt(time()); |
|
| 1397 | + } elseif ($currentTaskStatus === Task::STATUS_RUNNING && ($status === Task::STATUS_FAILED || $status === Task::STATUS_CANCELLED)) { |
|
| 1398 | + $task->setEndedAt(time()); |
|
| 1399 | + } elseif ($currentTaskStatus === Task::STATUS_UNKNOWN && $status === Task::STATUS_SCHEDULED) { |
|
| 1400 | + $task->setScheduledAt(time()); |
|
| 1401 | + } |
|
| 1402 | + $task->setStatus($status); |
|
| 1403 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1404 | + $this->taskMapper->update($taskEntity); |
|
| 1405 | + } |
|
| 1406 | + |
|
| 1407 | + /** |
|
| 1408 | + * Validate input, fill input default values, set completionExpectedAt, set scheduledAt |
|
| 1409 | + * |
|
| 1410 | + * @param Task $task |
|
| 1411 | + * @return void |
|
| 1412 | + * @throws UnauthorizedException |
|
| 1413 | + * @throws ValidationException |
|
| 1414 | + * @throws \OCP\TaskProcessing\Exception\Exception |
|
| 1415 | + */ |
|
| 1416 | + private function prepareTask(Task $task): void { |
|
| 1417 | + $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1418 | + $taskType = $taskTypes[$task->getTaskTypeId()]; |
|
| 1419 | + $inputShape = $taskType['inputShape']; |
|
| 1420 | + $inputShapeDefaults = $taskType['inputShapeDefaults']; |
|
| 1421 | + $inputShapeEnumValues = $taskType['inputShapeEnumValues']; |
|
| 1422 | + $optionalInputShape = $taskType['optionalInputShape']; |
|
| 1423 | + $optionalInputShapeEnumValues = $taskType['optionalInputShapeEnumValues']; |
|
| 1424 | + $optionalInputShapeDefaults = $taskType['optionalInputShapeDefaults']; |
|
| 1425 | + // validate input |
|
| 1426 | + $this->validateInput($inputShape, $inputShapeDefaults, $inputShapeEnumValues, $task->getInput()); |
|
| 1427 | + $this->validateInput($optionalInputShape, $optionalInputShapeDefaults, $optionalInputShapeEnumValues, $task->getInput(), true); |
|
| 1428 | + // authenticate access to mentioned files |
|
| 1429 | + $ids = []; |
|
| 1430 | + foreach ($inputShape + $optionalInputShape as $key => $descriptor) { |
|
| 1431 | + if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1432 | + /** @var list<int>|int $inputSlot */ |
|
| 1433 | + $inputSlot = $task->getInput()[$key]; |
|
| 1434 | + if (is_array($inputSlot)) { |
|
| 1435 | + $ids += $inputSlot; |
|
| 1436 | + } else { |
|
| 1437 | + $ids[] = $inputSlot; |
|
| 1438 | + } |
|
| 1439 | + } |
|
| 1440 | + } |
|
| 1441 | + foreach ($ids as $fileId) { |
|
| 1442 | + $this->validateFileId($fileId); |
|
| 1443 | + $this->validateUserAccessToFile($fileId, $task->getUserId()); |
|
| 1444 | + } |
|
| 1445 | + // remove superfluous keys and set input |
|
| 1446 | + $input = $this->removeSuperfluousArrayKeys($task->getInput(), $inputShape, $optionalInputShape); |
|
| 1447 | + $inputWithDefaults = $this->fillInputDefaults($input, $inputShapeDefaults, $optionalInputShapeDefaults); |
|
| 1448 | + $task->setInput($inputWithDefaults); |
|
| 1449 | + $task->setScheduledAt(time()); |
|
| 1450 | + $provider = $this->getPreferredProvider($task->getTaskTypeId()); |
|
| 1451 | + // calculate expected completion time |
|
| 1452 | + $completionExpectedAt = new \DateTime('now'); |
|
| 1453 | + $completionExpectedAt->add(new \DateInterval('PT' . $provider->getExpectedRuntime() . 'S')); |
|
| 1454 | + $task->setCompletionExpectedAt($completionExpectedAt); |
|
| 1455 | + } |
|
| 1456 | + |
|
| 1457 | + /** |
|
| 1458 | + * Store the task in the DB and set its ID in the \OCP\TaskProcessing\Task input param |
|
| 1459 | + * |
|
| 1460 | + * @param Task $task |
|
| 1461 | + * @return void |
|
| 1462 | + * @throws Exception |
|
| 1463 | + * @throws \JsonException |
|
| 1464 | + */ |
|
| 1465 | + private function storeTask(Task $task): void { |
|
| 1466 | + // create a db entity and insert into db table |
|
| 1467 | + $taskEntity = \OC\TaskProcessing\Db\Task::fromPublicTask($task); |
|
| 1468 | + $this->taskMapper->insert($taskEntity); |
|
| 1469 | + // make sure the scheduler knows the id |
|
| 1470 | + $task->setId($taskEntity->getId()); |
|
| 1471 | + } |
|
| 1472 | + |
|
| 1473 | + /** |
|
| 1474 | + * @param array $output |
|
| 1475 | + * @param ShapeDescriptor[] ...$specs the specs that define which keys to keep |
|
| 1476 | + * @return array |
|
| 1477 | + * @throws NotPermittedException |
|
| 1478 | + */ |
|
| 1479 | + private function validateOutputFileIds(array $output, ...$specs): array { |
|
| 1480 | + $newOutput = []; |
|
| 1481 | + $spec = array_reduce($specs, fn ($carry, $spec) => $carry + $spec, []); |
|
| 1482 | + foreach ($spec as $key => $descriptor) { |
|
| 1483 | + $type = $descriptor->getShapeType(); |
|
| 1484 | + if (!isset($output[$key])) { |
|
| 1485 | + continue; |
|
| 1486 | + } |
|
| 1487 | + if (!in_array(EShapeType::getScalarType($type), [EShapeType::Image, EShapeType::Audio, EShapeType::Video, EShapeType::File], true)) { |
|
| 1488 | + $newOutput[$key] = $output[$key]; |
|
| 1489 | + continue; |
|
| 1490 | + } |
|
| 1491 | + if (EShapeType::getScalarType($type) === $type) { |
|
| 1492 | + // Is scalar file ID |
|
| 1493 | + $newOutput[$key] = $this->validateFileId($output[$key]); |
|
| 1494 | + } else { |
|
| 1495 | + // Is list of file IDs |
|
| 1496 | + $newOutput = []; |
|
| 1497 | + foreach ($output[$key] as $item) { |
|
| 1498 | + $newOutput[$key][] = $this->validateFileId($item); |
|
| 1499 | + } |
|
| 1500 | + } |
|
| 1501 | + } |
|
| 1502 | + return $newOutput; |
|
| 1503 | + } |
|
| 1504 | + |
|
| 1505 | + /** |
|
| 1506 | + * @param mixed $id |
|
| 1507 | + * @return File |
|
| 1508 | + * @throws ValidationException |
|
| 1509 | + */ |
|
| 1510 | + private function validateFileId(mixed $id): File { |
|
| 1511 | + $node = $this->rootFolder->getFirstNodeById($id); |
|
| 1512 | + if ($node === null) { |
|
| 1513 | + $node = $this->rootFolder->getFirstNodeByIdInPath($id, '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 1514 | + if ($node === null) { |
|
| 1515 | + throw new ValidationException('Could not find file ' . $id); |
|
| 1516 | + } elseif (!$node instanceof File) { |
|
| 1517 | + throw new ValidationException('File with id "' . $id . '" is not a file'); |
|
| 1518 | + } |
|
| 1519 | + } elseif (!$node instanceof File) { |
|
| 1520 | + throw new ValidationException('File with id "' . $id . '" is not a file'); |
|
| 1521 | + } |
|
| 1522 | + return $node; |
|
| 1523 | + } |
|
| 1524 | + |
|
| 1525 | + /** |
|
| 1526 | + * @param mixed $fileId |
|
| 1527 | + * @param string|null $userId |
|
| 1528 | + * @return void |
|
| 1529 | + * @throws UnauthorizedException |
|
| 1530 | + */ |
|
| 1531 | + private function validateUserAccessToFile(mixed $fileId, ?string $userId): void { |
|
| 1532 | + if ($userId === null) { |
|
| 1533 | + throw new UnauthorizedException('User does not have access to file ' . $fileId); |
|
| 1534 | + } |
|
| 1535 | + $mounts = $this->userMountCache->getMountsForFileId($fileId); |
|
| 1536 | + $userIds = array_map(fn ($mount) => $mount->getUser()->getUID(), $mounts); |
|
| 1537 | + if (!in_array($userId, $userIds)) { |
|
| 1538 | + throw new UnauthorizedException('User ' . $userId . ' does not have access to file ' . $fileId); |
|
| 1539 | + } |
|
| 1540 | + } |
|
| 1541 | + |
|
| 1542 | + /** |
|
| 1543 | + * @param Task $task |
|
| 1544 | + * @return list<int> |
|
| 1545 | + * @throws NotFoundException |
|
| 1546 | + */ |
|
| 1547 | + public function extractFileIdsFromTask(Task $task): array { |
|
| 1548 | + $ids = []; |
|
| 1549 | + $taskTypes = $this->getAvailableTaskTypes(); |
|
| 1550 | + if (!isset($taskTypes[$task->getTaskTypeId()])) { |
|
| 1551 | + throw new NotFoundException('Could not find task type'); |
|
| 1552 | + } |
|
| 1553 | + $taskType = $taskTypes[$task->getTaskTypeId()]; |
|
| 1554 | + foreach ($taskType['inputShape'] + $taskType['optionalInputShape'] as $key => $descriptor) { |
|
| 1555 | + if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1556 | + /** @var int|list<int> $inputSlot */ |
|
| 1557 | + $inputSlot = $task->getInput()[$key]; |
|
| 1558 | + if (is_array($inputSlot)) { |
|
| 1559 | + $ids = array_merge($inputSlot, $ids); |
|
| 1560 | + } else { |
|
| 1561 | + $ids[] = $inputSlot; |
|
| 1562 | + } |
|
| 1563 | + } |
|
| 1564 | + } |
|
| 1565 | + if ($task->getOutput() !== null) { |
|
| 1566 | + foreach ($taskType['outputShape'] + $taskType['optionalOutputShape'] as $key => $descriptor) { |
|
| 1567 | + if (in_array(EShapeType::getScalarType($descriptor->getShapeType()), [EShapeType::File, EShapeType::Image, EShapeType::Audio, EShapeType::Video], true)) { |
|
| 1568 | + /** @var int|list<int> $outputSlot */ |
|
| 1569 | + $outputSlot = $task->getOutput()[$key]; |
|
| 1570 | + if (is_array($outputSlot)) { |
|
| 1571 | + $ids = array_merge($outputSlot, $ids); |
|
| 1572 | + } else { |
|
| 1573 | + $ids[] = $outputSlot; |
|
| 1574 | + } |
|
| 1575 | + } |
|
| 1576 | + } |
|
| 1577 | + } |
|
| 1578 | + return $ids; |
|
| 1579 | + } |
|
| 1580 | + |
|
| 1581 | + /** |
|
| 1582 | + * @param ISimpleFolder $folder |
|
| 1583 | + * @param int $ageInSeconds |
|
| 1584 | + * @return \Generator |
|
| 1585 | + */ |
|
| 1586 | + public function clearFilesOlderThan(ISimpleFolder $folder, int $ageInSeconds = self::MAX_TASK_AGE_SECONDS): \Generator { |
|
| 1587 | + foreach ($folder->getDirectoryListing() as $file) { |
|
| 1588 | + if ($file->getMTime() < time() - $ageInSeconds) { |
|
| 1589 | + try { |
|
| 1590 | + $fileName = $file->getName(); |
|
| 1591 | + $file->delete(); |
|
| 1592 | + yield $fileName; |
|
| 1593 | + } catch (NotPermittedException $e) { |
|
| 1594 | + $this->logger->warning('Failed to delete a stale task processing file', ['exception' => $e]); |
|
| 1595 | + } |
|
| 1596 | + } |
|
| 1597 | + } |
|
| 1598 | + } |
|
| 1599 | + |
|
| 1600 | + /** |
|
| 1601 | + * @param int $ageInSeconds |
|
| 1602 | + * @return \Generator |
|
| 1603 | + * @throws Exception |
|
| 1604 | + * @throws InvalidPathException |
|
| 1605 | + * @throws NotFoundException |
|
| 1606 | + * @throws \JsonException |
|
| 1607 | + * @throws \OCP\Files\NotFoundException |
|
| 1608 | + */ |
|
| 1609 | + public function cleanupTaskProcessingTaskFiles(int $ageInSeconds = self::MAX_TASK_AGE_SECONDS): \Generator { |
|
| 1610 | + $taskIdsToCleanup = []; |
|
| 1611 | + foreach ($this->taskMapper->getTasksToCleanup($ageInSeconds) as $task) { |
|
| 1612 | + $taskIdsToCleanup[] = $task->getId(); |
|
| 1613 | + $ocpTask = $task->toPublicTask(); |
|
| 1614 | + $fileIds = $this->extractFileIdsFromTask($ocpTask); |
|
| 1615 | + foreach ($fileIds as $fileId) { |
|
| 1616 | + // only look for output files stored in appData/TaskProcessing/ |
|
| 1617 | + $file = $this->rootFolder->getFirstNodeByIdInPath($fileId, '/' . $this->rootFolder->getAppDataDirectoryName() . '/core/TaskProcessing/'); |
|
| 1618 | + if ($file instanceof File) { |
|
| 1619 | + try { |
|
| 1620 | + $fileId = $file->getId(); |
|
| 1621 | + $fileName = $file->getName(); |
|
| 1622 | + $file->delete(); |
|
| 1623 | + yield ['task_id' => $task->getId(), 'file_id' => $fileId, 'file_name' => $fileName]; |
|
| 1624 | + } catch (NotPermittedException $e) { |
|
| 1625 | + $this->logger->warning('Failed to delete a stale task processing file', ['exception' => $e]); |
|
| 1626 | + } |
|
| 1627 | + } |
|
| 1628 | + } |
|
| 1629 | + } |
|
| 1630 | + return $taskIdsToCleanup; |
|
| 1631 | + } |
|
| 1632 | + |
|
| 1633 | + /** |
|
| 1634 | + * Make a request to the task's webhookUri if necessary |
|
| 1635 | + * |
|
| 1636 | + * @param Task $task |
|
| 1637 | + */ |
|
| 1638 | + private function runWebhook(Task $task): void { |
|
| 1639 | + $uri = $task->getWebhookUri(); |
|
| 1640 | + $method = $task->getWebhookMethod(); |
|
| 1641 | + |
|
| 1642 | + if (!$uri || !$method) { |
|
| 1643 | + return; |
|
| 1644 | + } |
|
| 1645 | + |
|
| 1646 | + if (in_array($method, ['HTTP:GET', 'HTTP:POST', 'HTTP:PUT', 'HTTP:DELETE'], true)) { |
|
| 1647 | + $client = $this->clientService->newClient(); |
|
| 1648 | + $httpMethod = preg_replace('/^HTTP:/', '', $method); |
|
| 1649 | + $options = [ |
|
| 1650 | + 'timeout' => 30, |
|
| 1651 | + 'body' => json_encode([ |
|
| 1652 | + 'task' => $task->jsonSerialize(), |
|
| 1653 | + ]), |
|
| 1654 | + 'headers' => ['Content-Type' => 'application/json'], |
|
| 1655 | + ]; |
|
| 1656 | + try { |
|
| 1657 | + $client->request($httpMethod, $uri, $options); |
|
| 1658 | + } catch (ClientException|ServerException $e) { |
|
| 1659 | + $this->logger->warning('Task processing HTTP webhook failed for task ' . $task->getId() . '. Request failed', ['exception' => $e]); |
|
| 1660 | + } catch (\Exception|\Throwable $e) { |
|
| 1661 | + $this->logger->warning('Task processing HTTP webhook failed for task ' . $task->getId() . '. Unknown error', ['exception' => $e]); |
|
| 1662 | + } |
|
| 1663 | + } elseif (str_starts_with($method, 'AppAPI:') && str_starts_with($uri, '/')) { |
|
| 1664 | + $parsedMethod = explode(':', $method, 4); |
|
| 1665 | + if (count($parsedMethod) < 3) { |
|
| 1666 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Invalid method: ' . $method); |
|
| 1667 | + } |
|
| 1668 | + [, $exAppId, $httpMethod] = $parsedMethod; |
|
| 1669 | + if (!$this->appManager->isEnabledForAnyone('app_api')) { |
|
| 1670 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. AppAPI is disabled or not installed.'); |
|
| 1671 | + return; |
|
| 1672 | + } |
|
| 1673 | + try { |
|
| 1674 | + $appApiFunctions = \OCP\Server::get(\OCA\AppAPI\PublicFunctions::class); |
|
| 1675 | + } catch (ContainerExceptionInterface|NotFoundExceptionInterface) { |
|
| 1676 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Could not get AppAPI public functions.'); |
|
| 1677 | + return; |
|
| 1678 | + } |
|
| 1679 | + $exApp = $appApiFunctions->getExApp($exAppId); |
|
| 1680 | + if ($exApp === null) { |
|
| 1681 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. ExApp ' . $exAppId . ' is missing.'); |
|
| 1682 | + return; |
|
| 1683 | + } elseif (!$exApp['enabled']) { |
|
| 1684 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. ExApp ' . $exAppId . ' is disabled.'); |
|
| 1685 | + return; |
|
| 1686 | + } |
|
| 1687 | + $requestParams = [ |
|
| 1688 | + 'task' => $task->jsonSerialize(), |
|
| 1689 | + ]; |
|
| 1690 | + $requestOptions = [ |
|
| 1691 | + 'timeout' => 30, |
|
| 1692 | + ]; |
|
| 1693 | + $response = $appApiFunctions->exAppRequest($exAppId, $uri, $task->getUserId(), $httpMethod, $requestParams, $requestOptions); |
|
| 1694 | + if (is_array($response) && isset($response['error'])) { |
|
| 1695 | + $this->logger->warning('Task processing AppAPI webhook failed for task ' . $task->getId() . '. Error during request to ExApp(' . $exAppId . '): ', $response['error']); |
|
| 1696 | + } |
|
| 1697 | + } |
|
| 1698 | + } |
|
| 1699 | 1699 | } |
@@ -59,1351 +59,1351 @@ |
||
| 59 | 59 | use Test\BackgroundJob\DummyJobList; |
| 60 | 60 | |
| 61 | 61 | class AudioToImage implements ITaskType { |
| 62 | - public const ID = 'test:audiotoimage'; |
|
| 63 | - |
|
| 64 | - public function getId(): string { |
|
| 65 | - return self::ID; |
|
| 66 | - } |
|
| 67 | - |
|
| 68 | - public function getName(): string { |
|
| 69 | - return self::class; |
|
| 70 | - } |
|
| 71 | - |
|
| 72 | - public function getDescription(): string { |
|
| 73 | - return self::class; |
|
| 74 | - } |
|
| 75 | - |
|
| 76 | - public function getInputShape(): array { |
|
| 77 | - return [ |
|
| 78 | - 'audio' => new ShapeDescriptor('Audio', 'The audio', EShapeType::Audio), |
|
| 79 | - ]; |
|
| 80 | - } |
|
| 81 | - |
|
| 82 | - public function getOutputShape(): array { |
|
| 83 | - return [ |
|
| 84 | - 'spectrogram' => new ShapeDescriptor('Spectrogram', 'The audio spectrogram', EShapeType::Image), |
|
| 85 | - ]; |
|
| 86 | - } |
|
| 62 | + public const ID = 'test:audiotoimage'; |
|
| 63 | + |
|
| 64 | + public function getId(): string { |
|
| 65 | + return self::ID; |
|
| 66 | + } |
|
| 67 | + |
|
| 68 | + public function getName(): string { |
|
| 69 | + return self::class; |
|
| 70 | + } |
|
| 71 | + |
|
| 72 | + public function getDescription(): string { |
|
| 73 | + return self::class; |
|
| 74 | + } |
|
| 75 | + |
|
| 76 | + public function getInputShape(): array { |
|
| 77 | + return [ |
|
| 78 | + 'audio' => new ShapeDescriptor('Audio', 'The audio', EShapeType::Audio), |
|
| 79 | + ]; |
|
| 80 | + } |
|
| 81 | + |
|
| 82 | + public function getOutputShape(): array { |
|
| 83 | + return [ |
|
| 84 | + 'spectrogram' => new ShapeDescriptor('Spectrogram', 'The audio spectrogram', EShapeType::Image), |
|
| 85 | + ]; |
|
| 86 | + } |
|
| 87 | 87 | } |
| 88 | 88 | |
| 89 | 89 | class AsyncProvider implements IProvider { |
| 90 | - public function getId(): string { |
|
| 91 | - return 'test:sync:success'; |
|
| 92 | - } |
|
| 93 | - |
|
| 94 | - public function getName(): string { |
|
| 95 | - return self::class; |
|
| 96 | - } |
|
| 97 | - |
|
| 98 | - public function getTaskTypeId(): string { |
|
| 99 | - return AudioToImage::ID; |
|
| 100 | - } |
|
| 101 | - |
|
| 102 | - public function getExpectedRuntime(): int { |
|
| 103 | - return 10; |
|
| 104 | - } |
|
| 105 | - |
|
| 106 | - public function getOptionalInputShape(): array { |
|
| 107 | - return [ |
|
| 108 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 109 | - ]; |
|
| 110 | - } |
|
| 111 | - |
|
| 112 | - public function getOptionalOutputShape(): array { |
|
| 113 | - return [ |
|
| 114 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 115 | - ]; |
|
| 116 | - } |
|
| 117 | - |
|
| 118 | - public function getInputShapeEnumValues(): array { |
|
| 119 | - return []; |
|
| 120 | - } |
|
| 121 | - |
|
| 122 | - public function getInputShapeDefaults(): array { |
|
| 123 | - return []; |
|
| 124 | - } |
|
| 125 | - |
|
| 126 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 127 | - return []; |
|
| 128 | - } |
|
| 129 | - |
|
| 130 | - public function getOptionalInputShapeDefaults(): array { |
|
| 131 | - return []; |
|
| 132 | - } |
|
| 133 | - |
|
| 134 | - public function getOutputShapeEnumValues(): array { |
|
| 135 | - return []; |
|
| 136 | - } |
|
| 137 | - |
|
| 138 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 139 | - return []; |
|
| 140 | - } |
|
| 90 | + public function getId(): string { |
|
| 91 | + return 'test:sync:success'; |
|
| 92 | + } |
|
| 93 | + |
|
| 94 | + public function getName(): string { |
|
| 95 | + return self::class; |
|
| 96 | + } |
|
| 97 | + |
|
| 98 | + public function getTaskTypeId(): string { |
|
| 99 | + return AudioToImage::ID; |
|
| 100 | + } |
|
| 101 | + |
|
| 102 | + public function getExpectedRuntime(): int { |
|
| 103 | + return 10; |
|
| 104 | + } |
|
| 105 | + |
|
| 106 | + public function getOptionalInputShape(): array { |
|
| 107 | + return [ |
|
| 108 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 109 | + ]; |
|
| 110 | + } |
|
| 111 | + |
|
| 112 | + public function getOptionalOutputShape(): array { |
|
| 113 | + return [ |
|
| 114 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 115 | + ]; |
|
| 116 | + } |
|
| 117 | + |
|
| 118 | + public function getInputShapeEnumValues(): array { |
|
| 119 | + return []; |
|
| 120 | + } |
|
| 121 | + |
|
| 122 | + public function getInputShapeDefaults(): array { |
|
| 123 | + return []; |
|
| 124 | + } |
|
| 125 | + |
|
| 126 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 127 | + return []; |
|
| 128 | + } |
|
| 129 | + |
|
| 130 | + public function getOptionalInputShapeDefaults(): array { |
|
| 131 | + return []; |
|
| 132 | + } |
|
| 133 | + |
|
| 134 | + public function getOutputShapeEnumValues(): array { |
|
| 135 | + return []; |
|
| 136 | + } |
|
| 137 | + |
|
| 138 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 139 | + return []; |
|
| 140 | + } |
|
| 141 | 141 | } |
| 142 | 142 | |
| 143 | 143 | class SuccessfulSyncProvider implements IProvider, ISynchronousProvider { |
| 144 | - public const ID = 'test:sync:success'; |
|
| 145 | - |
|
| 146 | - public function getId(): string { |
|
| 147 | - return self::ID; |
|
| 148 | - } |
|
| 149 | - |
|
| 150 | - public function getName(): string { |
|
| 151 | - return self::class; |
|
| 152 | - } |
|
| 153 | - |
|
| 154 | - public function getTaskTypeId(): string { |
|
| 155 | - return TextToText::ID; |
|
| 156 | - } |
|
| 157 | - |
|
| 158 | - public function getExpectedRuntime(): int { |
|
| 159 | - return 10; |
|
| 160 | - } |
|
| 161 | - |
|
| 162 | - public function getOptionalInputShape(): array { |
|
| 163 | - return [ |
|
| 164 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 165 | - ]; |
|
| 166 | - } |
|
| 167 | - |
|
| 168 | - public function getOptionalOutputShape(): array { |
|
| 169 | - return [ |
|
| 170 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 171 | - ]; |
|
| 172 | - } |
|
| 173 | - |
|
| 174 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 175 | - return ['output' => $input['input']]; |
|
| 176 | - } |
|
| 177 | - |
|
| 178 | - public function getInputShapeEnumValues(): array { |
|
| 179 | - return []; |
|
| 180 | - } |
|
| 181 | - |
|
| 182 | - public function getInputShapeDefaults(): array { |
|
| 183 | - return []; |
|
| 184 | - } |
|
| 185 | - |
|
| 186 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 187 | - return []; |
|
| 188 | - } |
|
| 189 | - |
|
| 190 | - public function getOptionalInputShapeDefaults(): array { |
|
| 191 | - return []; |
|
| 192 | - } |
|
| 193 | - |
|
| 194 | - public function getOutputShapeEnumValues(): array { |
|
| 195 | - return []; |
|
| 196 | - } |
|
| 197 | - |
|
| 198 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 199 | - return []; |
|
| 200 | - } |
|
| 144 | + public const ID = 'test:sync:success'; |
|
| 145 | + |
|
| 146 | + public function getId(): string { |
|
| 147 | + return self::ID; |
|
| 148 | + } |
|
| 149 | + |
|
| 150 | + public function getName(): string { |
|
| 151 | + return self::class; |
|
| 152 | + } |
|
| 153 | + |
|
| 154 | + public function getTaskTypeId(): string { |
|
| 155 | + return TextToText::ID; |
|
| 156 | + } |
|
| 157 | + |
|
| 158 | + public function getExpectedRuntime(): int { |
|
| 159 | + return 10; |
|
| 160 | + } |
|
| 161 | + |
|
| 162 | + public function getOptionalInputShape(): array { |
|
| 163 | + return [ |
|
| 164 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 165 | + ]; |
|
| 166 | + } |
|
| 167 | + |
|
| 168 | + public function getOptionalOutputShape(): array { |
|
| 169 | + return [ |
|
| 170 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 171 | + ]; |
|
| 172 | + } |
|
| 173 | + |
|
| 174 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 175 | + return ['output' => $input['input']]; |
|
| 176 | + } |
|
| 177 | + |
|
| 178 | + public function getInputShapeEnumValues(): array { |
|
| 179 | + return []; |
|
| 180 | + } |
|
| 181 | + |
|
| 182 | + public function getInputShapeDefaults(): array { |
|
| 183 | + return []; |
|
| 184 | + } |
|
| 185 | + |
|
| 186 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 187 | + return []; |
|
| 188 | + } |
|
| 189 | + |
|
| 190 | + public function getOptionalInputShapeDefaults(): array { |
|
| 191 | + return []; |
|
| 192 | + } |
|
| 193 | + |
|
| 194 | + public function getOutputShapeEnumValues(): array { |
|
| 195 | + return []; |
|
| 196 | + } |
|
| 197 | + |
|
| 198 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 199 | + return []; |
|
| 200 | + } |
|
| 201 | 201 | } |
| 202 | 202 | |
| 203 | 203 | |
| 204 | 204 | |
| 205 | 205 | class FailingSyncProvider implements IProvider, ISynchronousProvider { |
| 206 | - public const ERROR_MESSAGE = 'Failure'; |
|
| 207 | - public function getId(): string { |
|
| 208 | - return 'test:sync:fail'; |
|
| 209 | - } |
|
| 210 | - |
|
| 211 | - public function getName(): string { |
|
| 212 | - return self::class; |
|
| 213 | - } |
|
| 214 | - |
|
| 215 | - public function getTaskTypeId(): string { |
|
| 216 | - return TextToText::ID; |
|
| 217 | - } |
|
| 218 | - |
|
| 219 | - public function getExpectedRuntime(): int { |
|
| 220 | - return 10; |
|
| 221 | - } |
|
| 222 | - |
|
| 223 | - public function getOptionalInputShape(): array { |
|
| 224 | - return [ |
|
| 225 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 226 | - ]; |
|
| 227 | - } |
|
| 228 | - |
|
| 229 | - public function getOptionalOutputShape(): array { |
|
| 230 | - return [ |
|
| 231 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 232 | - ]; |
|
| 233 | - } |
|
| 234 | - |
|
| 235 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 236 | - throw new ProcessingException(self::ERROR_MESSAGE); |
|
| 237 | - } |
|
| 238 | - |
|
| 239 | - public function getInputShapeEnumValues(): array { |
|
| 240 | - return []; |
|
| 241 | - } |
|
| 242 | - |
|
| 243 | - public function getInputShapeDefaults(): array { |
|
| 244 | - return []; |
|
| 245 | - } |
|
| 246 | - |
|
| 247 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 248 | - return []; |
|
| 249 | - } |
|
| 250 | - |
|
| 251 | - public function getOptionalInputShapeDefaults(): array { |
|
| 252 | - return []; |
|
| 253 | - } |
|
| 254 | - |
|
| 255 | - public function getOutputShapeEnumValues(): array { |
|
| 256 | - return []; |
|
| 257 | - } |
|
| 258 | - |
|
| 259 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 260 | - return []; |
|
| 261 | - } |
|
| 206 | + public const ERROR_MESSAGE = 'Failure'; |
|
| 207 | + public function getId(): string { |
|
| 208 | + return 'test:sync:fail'; |
|
| 209 | + } |
|
| 210 | + |
|
| 211 | + public function getName(): string { |
|
| 212 | + return self::class; |
|
| 213 | + } |
|
| 214 | + |
|
| 215 | + public function getTaskTypeId(): string { |
|
| 216 | + return TextToText::ID; |
|
| 217 | + } |
|
| 218 | + |
|
| 219 | + public function getExpectedRuntime(): int { |
|
| 220 | + return 10; |
|
| 221 | + } |
|
| 222 | + |
|
| 223 | + public function getOptionalInputShape(): array { |
|
| 224 | + return [ |
|
| 225 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 226 | + ]; |
|
| 227 | + } |
|
| 228 | + |
|
| 229 | + public function getOptionalOutputShape(): array { |
|
| 230 | + return [ |
|
| 231 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 232 | + ]; |
|
| 233 | + } |
|
| 234 | + |
|
| 235 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 236 | + throw new ProcessingException(self::ERROR_MESSAGE); |
|
| 237 | + } |
|
| 238 | + |
|
| 239 | + public function getInputShapeEnumValues(): array { |
|
| 240 | + return []; |
|
| 241 | + } |
|
| 242 | + |
|
| 243 | + public function getInputShapeDefaults(): array { |
|
| 244 | + return []; |
|
| 245 | + } |
|
| 246 | + |
|
| 247 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 248 | + return []; |
|
| 249 | + } |
|
| 250 | + |
|
| 251 | + public function getOptionalInputShapeDefaults(): array { |
|
| 252 | + return []; |
|
| 253 | + } |
|
| 254 | + |
|
| 255 | + public function getOutputShapeEnumValues(): array { |
|
| 256 | + return []; |
|
| 257 | + } |
|
| 258 | + |
|
| 259 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 260 | + return []; |
|
| 261 | + } |
|
| 262 | 262 | } |
| 263 | 263 | |
| 264 | 264 | class BrokenSyncProvider implements IProvider, ISynchronousProvider { |
| 265 | - public function getId(): string { |
|
| 266 | - return 'test:sync:broken-output'; |
|
| 267 | - } |
|
| 268 | - |
|
| 269 | - public function getName(): string { |
|
| 270 | - return self::class; |
|
| 271 | - } |
|
| 272 | - |
|
| 273 | - public function getTaskTypeId(): string { |
|
| 274 | - return TextToText::ID; |
|
| 275 | - } |
|
| 276 | - |
|
| 277 | - public function getExpectedRuntime(): int { |
|
| 278 | - return 10; |
|
| 279 | - } |
|
| 280 | - |
|
| 281 | - public function getOptionalInputShape(): array { |
|
| 282 | - return [ |
|
| 283 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 284 | - ]; |
|
| 285 | - } |
|
| 286 | - |
|
| 287 | - public function getOptionalOutputShape(): array { |
|
| 288 | - return [ |
|
| 289 | - 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 290 | - ]; |
|
| 291 | - } |
|
| 292 | - |
|
| 293 | - public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 294 | - return []; |
|
| 295 | - } |
|
| 296 | - |
|
| 297 | - public function getInputShapeEnumValues(): array { |
|
| 298 | - return []; |
|
| 299 | - } |
|
| 300 | - |
|
| 301 | - public function getInputShapeDefaults(): array { |
|
| 302 | - return []; |
|
| 303 | - } |
|
| 304 | - |
|
| 305 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 306 | - return []; |
|
| 307 | - } |
|
| 308 | - |
|
| 309 | - public function getOptionalInputShapeDefaults(): array { |
|
| 310 | - return []; |
|
| 311 | - } |
|
| 312 | - |
|
| 313 | - public function getOutputShapeEnumValues(): array { |
|
| 314 | - return []; |
|
| 315 | - } |
|
| 316 | - |
|
| 317 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 318 | - return []; |
|
| 319 | - } |
|
| 265 | + public function getId(): string { |
|
| 266 | + return 'test:sync:broken-output'; |
|
| 267 | + } |
|
| 268 | + |
|
| 269 | + public function getName(): string { |
|
| 270 | + return self::class; |
|
| 271 | + } |
|
| 272 | + |
|
| 273 | + public function getTaskTypeId(): string { |
|
| 274 | + return TextToText::ID; |
|
| 275 | + } |
|
| 276 | + |
|
| 277 | + public function getExpectedRuntime(): int { |
|
| 278 | + return 10; |
|
| 279 | + } |
|
| 280 | + |
|
| 281 | + public function getOptionalInputShape(): array { |
|
| 282 | + return [ |
|
| 283 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 284 | + ]; |
|
| 285 | + } |
|
| 286 | + |
|
| 287 | + public function getOptionalOutputShape(): array { |
|
| 288 | + return [ |
|
| 289 | + 'optionalKey' => new ShapeDescriptor('optional Key', 'AN optional key', EShapeType::Text), |
|
| 290 | + ]; |
|
| 291 | + } |
|
| 292 | + |
|
| 293 | + public function process(?string $userId, array $input, callable $reportProgress): array { |
|
| 294 | + return []; |
|
| 295 | + } |
|
| 296 | + |
|
| 297 | + public function getInputShapeEnumValues(): array { |
|
| 298 | + return []; |
|
| 299 | + } |
|
| 300 | + |
|
| 301 | + public function getInputShapeDefaults(): array { |
|
| 302 | + return []; |
|
| 303 | + } |
|
| 304 | + |
|
| 305 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 306 | + return []; |
|
| 307 | + } |
|
| 308 | + |
|
| 309 | + public function getOptionalInputShapeDefaults(): array { |
|
| 310 | + return []; |
|
| 311 | + } |
|
| 312 | + |
|
| 313 | + public function getOutputShapeEnumValues(): array { |
|
| 314 | + return []; |
|
| 315 | + } |
|
| 316 | + |
|
| 317 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 318 | + return []; |
|
| 319 | + } |
|
| 320 | 320 | } |
| 321 | 321 | |
| 322 | 322 | class SuccessfulTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { |
| 323 | - public bool $ran = false; |
|
| 323 | + public bool $ran = false; |
|
| 324 | 324 | |
| 325 | - public function getName(): string { |
|
| 326 | - return 'TEST Vanilla LLM Provider'; |
|
| 327 | - } |
|
| 325 | + public function getName(): string { |
|
| 326 | + return 'TEST Vanilla LLM Provider'; |
|
| 327 | + } |
|
| 328 | 328 | |
| 329 | - public function process(string $prompt): string { |
|
| 330 | - $this->ran = true; |
|
| 331 | - return $prompt . ' Summarize'; |
|
| 332 | - } |
|
| 329 | + public function process(string $prompt): string { |
|
| 330 | + $this->ran = true; |
|
| 331 | + return $prompt . ' Summarize'; |
|
| 332 | + } |
|
| 333 | 333 | |
| 334 | - public function getTaskType(): string { |
|
| 335 | - return SummaryTaskType::class; |
|
| 336 | - } |
|
| 334 | + public function getTaskType(): string { |
|
| 335 | + return SummaryTaskType::class; |
|
| 336 | + } |
|
| 337 | 337 | } |
| 338 | 338 | |
| 339 | 339 | class FailingTextProcessingSummaryProvider implements \OCP\TextProcessing\IProvider { |
| 340 | - public bool $ran = false; |
|
| 340 | + public bool $ran = false; |
|
| 341 | 341 | |
| 342 | - public function getName(): string { |
|
| 343 | - return 'TEST Vanilla LLM Provider'; |
|
| 344 | - } |
|
| 342 | + public function getName(): string { |
|
| 343 | + return 'TEST Vanilla LLM Provider'; |
|
| 344 | + } |
|
| 345 | 345 | |
| 346 | - public function process(string $prompt): string { |
|
| 347 | - $this->ran = true; |
|
| 348 | - throw new \Exception('ERROR'); |
|
| 349 | - } |
|
| 346 | + public function process(string $prompt): string { |
|
| 347 | + $this->ran = true; |
|
| 348 | + throw new \Exception('ERROR'); |
|
| 349 | + } |
|
| 350 | 350 | |
| 351 | - public function getTaskType(): string { |
|
| 352 | - return SummaryTaskType::class; |
|
| 353 | - } |
|
| 351 | + public function getTaskType(): string { |
|
| 352 | + return SummaryTaskType::class; |
|
| 353 | + } |
|
| 354 | 354 | } |
| 355 | 355 | |
| 356 | 356 | class SuccessfulTextToImageProvider implements \OCP\TextToImage\IProvider { |
| 357 | - public bool $ran = false; |
|
| 358 | - |
|
| 359 | - public function getId(): string { |
|
| 360 | - return 'test:successful'; |
|
| 361 | - } |
|
| 362 | - |
|
| 363 | - public function getName(): string { |
|
| 364 | - return 'TEST Provider'; |
|
| 365 | - } |
|
| 366 | - |
|
| 367 | - public function generate(string $prompt, array $resources): void { |
|
| 368 | - $this->ran = true; |
|
| 369 | - foreach ($resources as $resource) { |
|
| 370 | - fwrite($resource, 'test'); |
|
| 371 | - } |
|
| 372 | - } |
|
| 373 | - |
|
| 374 | - public function getExpectedRuntime(): int { |
|
| 375 | - return 1; |
|
| 376 | - } |
|
| 357 | + public bool $ran = false; |
|
| 358 | + |
|
| 359 | + public function getId(): string { |
|
| 360 | + return 'test:successful'; |
|
| 361 | + } |
|
| 362 | + |
|
| 363 | + public function getName(): string { |
|
| 364 | + return 'TEST Provider'; |
|
| 365 | + } |
|
| 366 | + |
|
| 367 | + public function generate(string $prompt, array $resources): void { |
|
| 368 | + $this->ran = true; |
|
| 369 | + foreach ($resources as $resource) { |
|
| 370 | + fwrite($resource, 'test'); |
|
| 371 | + } |
|
| 372 | + } |
|
| 373 | + |
|
| 374 | + public function getExpectedRuntime(): int { |
|
| 375 | + return 1; |
|
| 376 | + } |
|
| 377 | 377 | } |
| 378 | 378 | |
| 379 | 379 | class FailingTextToImageProvider implements \OCP\TextToImage\IProvider { |
| 380 | - public bool $ran = false; |
|
| 380 | + public bool $ran = false; |
|
| 381 | 381 | |
| 382 | - public function getId(): string { |
|
| 383 | - return 'test:failing'; |
|
| 384 | - } |
|
| 382 | + public function getId(): string { |
|
| 383 | + return 'test:failing'; |
|
| 384 | + } |
|
| 385 | 385 | |
| 386 | - public function getName(): string { |
|
| 387 | - return 'TEST Provider'; |
|
| 388 | - } |
|
| 386 | + public function getName(): string { |
|
| 387 | + return 'TEST Provider'; |
|
| 388 | + } |
|
| 389 | 389 | |
| 390 | - public function generate(string $prompt, array $resources): void { |
|
| 391 | - $this->ran = true; |
|
| 392 | - throw new \RuntimeException('ERROR'); |
|
| 393 | - } |
|
| 390 | + public function generate(string $prompt, array $resources): void { |
|
| 391 | + $this->ran = true; |
|
| 392 | + throw new \RuntimeException('ERROR'); |
|
| 393 | + } |
|
| 394 | 394 | |
| 395 | - public function getExpectedRuntime(): int { |
|
| 396 | - return 1; |
|
| 397 | - } |
|
| 395 | + public function getExpectedRuntime(): int { |
|
| 396 | + return 1; |
|
| 397 | + } |
|
| 398 | 398 | } |
| 399 | 399 | |
| 400 | 400 | class ExternalProvider implements IProvider { |
| 401 | - public const ID = 'event:external:provider'; |
|
| 402 | - public const TASK_TYPE_ID = 'event:external:tasktype'; |
|
| 403 | - |
|
| 404 | - public function getId(): string { |
|
| 405 | - return self::ID; |
|
| 406 | - } |
|
| 407 | - public function getName(): string { |
|
| 408 | - return 'External Provider via Event'; |
|
| 409 | - } |
|
| 410 | - public function getTaskTypeId(): string { |
|
| 411 | - return self::TASK_TYPE_ID; |
|
| 412 | - } |
|
| 413 | - public function getExpectedRuntime(): int { |
|
| 414 | - return 5; |
|
| 415 | - } |
|
| 416 | - public function getOptionalInputShape(): array { |
|
| 417 | - return []; |
|
| 418 | - } |
|
| 419 | - public function getOptionalOutputShape(): array { |
|
| 420 | - return []; |
|
| 421 | - } |
|
| 422 | - public function getInputShapeEnumValues(): array { |
|
| 423 | - return []; |
|
| 424 | - } |
|
| 425 | - public function getInputShapeDefaults(): array { |
|
| 426 | - return []; |
|
| 427 | - } |
|
| 428 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 429 | - return []; |
|
| 430 | - } |
|
| 431 | - public function getOptionalInputShapeDefaults(): array { |
|
| 432 | - return []; |
|
| 433 | - } |
|
| 434 | - public function getOutputShapeEnumValues(): array { |
|
| 435 | - return []; |
|
| 436 | - } |
|
| 437 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 438 | - return []; |
|
| 439 | - } |
|
| 401 | + public const ID = 'event:external:provider'; |
|
| 402 | + public const TASK_TYPE_ID = 'event:external:tasktype'; |
|
| 403 | + |
|
| 404 | + public function getId(): string { |
|
| 405 | + return self::ID; |
|
| 406 | + } |
|
| 407 | + public function getName(): string { |
|
| 408 | + return 'External Provider via Event'; |
|
| 409 | + } |
|
| 410 | + public function getTaskTypeId(): string { |
|
| 411 | + return self::TASK_TYPE_ID; |
|
| 412 | + } |
|
| 413 | + public function getExpectedRuntime(): int { |
|
| 414 | + return 5; |
|
| 415 | + } |
|
| 416 | + public function getOptionalInputShape(): array { |
|
| 417 | + return []; |
|
| 418 | + } |
|
| 419 | + public function getOptionalOutputShape(): array { |
|
| 420 | + return []; |
|
| 421 | + } |
|
| 422 | + public function getInputShapeEnumValues(): array { |
|
| 423 | + return []; |
|
| 424 | + } |
|
| 425 | + public function getInputShapeDefaults(): array { |
|
| 426 | + return []; |
|
| 427 | + } |
|
| 428 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 429 | + return []; |
|
| 430 | + } |
|
| 431 | + public function getOptionalInputShapeDefaults(): array { |
|
| 432 | + return []; |
|
| 433 | + } |
|
| 434 | + public function getOutputShapeEnumValues(): array { |
|
| 435 | + return []; |
|
| 436 | + } |
|
| 437 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 438 | + return []; |
|
| 439 | + } |
|
| 440 | 440 | } |
| 441 | 441 | |
| 442 | 442 | |
| 443 | 443 | class ExternalTriggerableProvider implements ITriggerableProvider { |
| 444 | - public const ID = 'event:external:provider:triggerable'; |
|
| 445 | - public const TASK_TYPE_ID = TextToText::ID; |
|
| 446 | - |
|
| 447 | - public function getId(): string { |
|
| 448 | - return self::ID; |
|
| 449 | - } |
|
| 450 | - public function getName(): string { |
|
| 451 | - return 'External Triggerable Provider via Event'; |
|
| 452 | - } |
|
| 453 | - |
|
| 454 | - public function getTaskTypeId(): string { |
|
| 455 | - return self::TASK_TYPE_ID; |
|
| 456 | - } |
|
| 457 | - |
|
| 458 | - public function trigger(): void { |
|
| 459 | - } |
|
| 460 | - public function getExpectedRuntime(): int { |
|
| 461 | - return 5; |
|
| 462 | - } |
|
| 463 | - public function getOptionalInputShape(): array { |
|
| 464 | - return []; |
|
| 465 | - } |
|
| 466 | - public function getOptionalOutputShape(): array { |
|
| 467 | - return []; |
|
| 468 | - } |
|
| 469 | - public function getInputShapeEnumValues(): array { |
|
| 470 | - return []; |
|
| 471 | - } |
|
| 472 | - public function getInputShapeDefaults(): array { |
|
| 473 | - return []; |
|
| 474 | - } |
|
| 475 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 476 | - return []; |
|
| 477 | - } |
|
| 478 | - public function getOptionalInputShapeDefaults(): array { |
|
| 479 | - return []; |
|
| 480 | - } |
|
| 481 | - public function getOutputShapeEnumValues(): array { |
|
| 482 | - return []; |
|
| 483 | - } |
|
| 484 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 485 | - return []; |
|
| 486 | - } |
|
| 444 | + public const ID = 'event:external:provider:triggerable'; |
|
| 445 | + public const TASK_TYPE_ID = TextToText::ID; |
|
| 446 | + |
|
| 447 | + public function getId(): string { |
|
| 448 | + return self::ID; |
|
| 449 | + } |
|
| 450 | + public function getName(): string { |
|
| 451 | + return 'External Triggerable Provider via Event'; |
|
| 452 | + } |
|
| 453 | + |
|
| 454 | + public function getTaskTypeId(): string { |
|
| 455 | + return self::TASK_TYPE_ID; |
|
| 456 | + } |
|
| 457 | + |
|
| 458 | + public function trigger(): void { |
|
| 459 | + } |
|
| 460 | + public function getExpectedRuntime(): int { |
|
| 461 | + return 5; |
|
| 462 | + } |
|
| 463 | + public function getOptionalInputShape(): array { |
|
| 464 | + return []; |
|
| 465 | + } |
|
| 466 | + public function getOptionalOutputShape(): array { |
|
| 467 | + return []; |
|
| 468 | + } |
|
| 469 | + public function getInputShapeEnumValues(): array { |
|
| 470 | + return []; |
|
| 471 | + } |
|
| 472 | + public function getInputShapeDefaults(): array { |
|
| 473 | + return []; |
|
| 474 | + } |
|
| 475 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 476 | + return []; |
|
| 477 | + } |
|
| 478 | + public function getOptionalInputShapeDefaults(): array { |
|
| 479 | + return []; |
|
| 480 | + } |
|
| 481 | + public function getOutputShapeEnumValues(): array { |
|
| 482 | + return []; |
|
| 483 | + } |
|
| 484 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 485 | + return []; |
|
| 486 | + } |
|
| 487 | 487 | } |
| 488 | 488 | |
| 489 | 489 | class ConflictingExternalProvider implements IProvider { |
| 490 | - // Same ID as SuccessfulSyncProvider |
|
| 491 | - public const ID = 'test:sync:success'; |
|
| 492 | - public const TASK_TYPE_ID = 'event:external:tasktype'; // Can be different task type |
|
| 493 | - |
|
| 494 | - public function getId(): string { |
|
| 495 | - return self::ID; |
|
| 496 | - } |
|
| 497 | - public function getName(): string { |
|
| 498 | - return 'Conflicting External Provider'; |
|
| 499 | - } |
|
| 500 | - public function getTaskTypeId(): string { |
|
| 501 | - return self::TASK_TYPE_ID; |
|
| 502 | - } |
|
| 503 | - public function getExpectedRuntime(): int { |
|
| 504 | - return 50; |
|
| 505 | - } |
|
| 506 | - public function getOptionalInputShape(): array { |
|
| 507 | - return []; |
|
| 508 | - } |
|
| 509 | - public function getOptionalOutputShape(): array { |
|
| 510 | - return []; |
|
| 511 | - } |
|
| 512 | - public function getInputShapeEnumValues(): array { |
|
| 513 | - return []; |
|
| 514 | - } |
|
| 515 | - public function getInputShapeDefaults(): array { |
|
| 516 | - return []; |
|
| 517 | - } |
|
| 518 | - public function getOptionalInputShapeEnumValues(): array { |
|
| 519 | - return []; |
|
| 520 | - } |
|
| 521 | - public function getOptionalInputShapeDefaults(): array { |
|
| 522 | - return []; |
|
| 523 | - } |
|
| 524 | - public function getOutputShapeEnumValues(): array { |
|
| 525 | - return []; |
|
| 526 | - } |
|
| 527 | - public function getOptionalOutputShapeEnumValues(): array { |
|
| 528 | - return []; |
|
| 529 | - } |
|
| 490 | + // Same ID as SuccessfulSyncProvider |
|
| 491 | + public const ID = 'test:sync:success'; |
|
| 492 | + public const TASK_TYPE_ID = 'event:external:tasktype'; // Can be different task type |
|
| 493 | + |
|
| 494 | + public function getId(): string { |
|
| 495 | + return self::ID; |
|
| 496 | + } |
|
| 497 | + public function getName(): string { |
|
| 498 | + return 'Conflicting External Provider'; |
|
| 499 | + } |
|
| 500 | + public function getTaskTypeId(): string { |
|
| 501 | + return self::TASK_TYPE_ID; |
|
| 502 | + } |
|
| 503 | + public function getExpectedRuntime(): int { |
|
| 504 | + return 50; |
|
| 505 | + } |
|
| 506 | + public function getOptionalInputShape(): array { |
|
| 507 | + return []; |
|
| 508 | + } |
|
| 509 | + public function getOptionalOutputShape(): array { |
|
| 510 | + return []; |
|
| 511 | + } |
|
| 512 | + public function getInputShapeEnumValues(): array { |
|
| 513 | + return []; |
|
| 514 | + } |
|
| 515 | + public function getInputShapeDefaults(): array { |
|
| 516 | + return []; |
|
| 517 | + } |
|
| 518 | + public function getOptionalInputShapeEnumValues(): array { |
|
| 519 | + return []; |
|
| 520 | + } |
|
| 521 | + public function getOptionalInputShapeDefaults(): array { |
|
| 522 | + return []; |
|
| 523 | + } |
|
| 524 | + public function getOutputShapeEnumValues(): array { |
|
| 525 | + return []; |
|
| 526 | + } |
|
| 527 | + public function getOptionalOutputShapeEnumValues(): array { |
|
| 528 | + return []; |
|
| 529 | + } |
|
| 530 | 530 | } |
| 531 | 531 | |
| 532 | 532 | class ExternalTaskType implements ITaskType { |
| 533 | - public const ID = 'event:external:tasktype'; |
|
| 534 | - |
|
| 535 | - public function getId(): string { |
|
| 536 | - return self::ID; |
|
| 537 | - } |
|
| 538 | - public function getName(): string { |
|
| 539 | - return 'External Task Type via Event'; |
|
| 540 | - } |
|
| 541 | - public function getDescription(): string { |
|
| 542 | - return 'A task type added via event'; |
|
| 543 | - } |
|
| 544 | - public function getInputShape(): array { |
|
| 545 | - return ['external_input' => new ShapeDescriptor('Ext In', '', EShapeType::Text)]; |
|
| 546 | - } |
|
| 547 | - public function getOutputShape(): array { |
|
| 548 | - return ['external_output' => new ShapeDescriptor('Ext Out', '', EShapeType::Text)]; |
|
| 549 | - } |
|
| 533 | + public const ID = 'event:external:tasktype'; |
|
| 534 | + |
|
| 535 | + public function getId(): string { |
|
| 536 | + return self::ID; |
|
| 537 | + } |
|
| 538 | + public function getName(): string { |
|
| 539 | + return 'External Task Type via Event'; |
|
| 540 | + } |
|
| 541 | + public function getDescription(): string { |
|
| 542 | + return 'A task type added via event'; |
|
| 543 | + } |
|
| 544 | + public function getInputShape(): array { |
|
| 545 | + return ['external_input' => new ShapeDescriptor('Ext In', '', EShapeType::Text)]; |
|
| 546 | + } |
|
| 547 | + public function getOutputShape(): array { |
|
| 548 | + return ['external_output' => new ShapeDescriptor('Ext Out', '', EShapeType::Text)]; |
|
| 549 | + } |
|
| 550 | 550 | } |
| 551 | 551 | |
| 552 | 552 | class ConflictingExternalTaskType implements ITaskType { |
| 553 | - // Same ID as built-in TextToText |
|
| 554 | - public const ID = TextToText::ID; |
|
| 555 | - |
|
| 556 | - public function getId(): string { |
|
| 557 | - return self::ID; |
|
| 558 | - } |
|
| 559 | - public function getName(): string { |
|
| 560 | - return 'Conflicting External Task Type'; |
|
| 561 | - } |
|
| 562 | - public function getDescription(): string { |
|
| 563 | - return 'Overrides built-in TextToText'; |
|
| 564 | - } |
|
| 565 | - public function getInputShape(): array { |
|
| 566 | - return ['override_input' => new ShapeDescriptor('Override In', '', EShapeType::Number)]; |
|
| 567 | - } |
|
| 568 | - public function getOutputShape(): array { |
|
| 569 | - return ['override_output' => new ShapeDescriptor('Override Out', '', EShapeType::Number)]; |
|
| 570 | - } |
|
| 553 | + // Same ID as built-in TextToText |
|
| 554 | + public const ID = TextToText::ID; |
|
| 555 | + |
|
| 556 | + public function getId(): string { |
|
| 557 | + return self::ID; |
|
| 558 | + } |
|
| 559 | + public function getName(): string { |
|
| 560 | + return 'Conflicting External Task Type'; |
|
| 561 | + } |
|
| 562 | + public function getDescription(): string { |
|
| 563 | + return 'Overrides built-in TextToText'; |
|
| 564 | + } |
|
| 565 | + public function getInputShape(): array { |
|
| 566 | + return ['override_input' => new ShapeDescriptor('Override In', '', EShapeType::Number)]; |
|
| 567 | + } |
|
| 568 | + public function getOutputShape(): array { |
|
| 569 | + return ['override_output' => new ShapeDescriptor('Override Out', '', EShapeType::Number)]; |
|
| 570 | + } |
|
| 571 | 571 | } |
| 572 | 572 | |
| 573 | 573 | /** |
| 574 | 574 | * @group DB |
| 575 | 575 | */ |
| 576 | 576 | class TaskProcessingTest extends \Test\TestCase { |
| 577 | - private IManager $manager; |
|
| 578 | - private Coordinator $coordinator; |
|
| 579 | - private array $providers; |
|
| 580 | - private IServerContainer $serverContainer; |
|
| 581 | - private IEventDispatcher $eventDispatcher; |
|
| 582 | - private RegistrationContext $registrationContext; |
|
| 583 | - private TaskMapper $taskMapper; |
|
| 584 | - private IJobList $jobList; |
|
| 585 | - private IUserMountCache $userMountCache; |
|
| 586 | - private IRootFolder $rootFolder; |
|
| 587 | - private IConfig $config; |
|
| 588 | - private IAppConfig $appConfig; |
|
| 589 | - |
|
| 590 | - public const TEST_USER = 'testuser'; |
|
| 591 | - |
|
| 592 | - protected function setUp(): void { |
|
| 593 | - parent::setUp(); |
|
| 594 | - |
|
| 595 | - $this->providers = [ |
|
| 596 | - SuccessfulSyncProvider::class => new SuccessfulSyncProvider(), |
|
| 597 | - FailingSyncProvider::class => new FailingSyncProvider(), |
|
| 598 | - BrokenSyncProvider::class => new BrokenSyncProvider(), |
|
| 599 | - AsyncProvider::class => new AsyncProvider(), |
|
| 600 | - AudioToImage::class => new AudioToImage(), |
|
| 601 | - SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(), |
|
| 602 | - FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(), |
|
| 603 | - SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(), |
|
| 604 | - FailingTextToImageProvider::class => new FailingTextToImageProvider(), |
|
| 605 | - ExternalProvider::class => new ExternalProvider(), |
|
| 606 | - ExternalTriggerableProvider::class => new ExternalTriggerableProvider(), |
|
| 607 | - ConflictingExternalProvider::class => new ConflictingExternalProvider(), |
|
| 608 | - ExternalTaskType::class => new ExternalTaskType(), |
|
| 609 | - ConflictingExternalTaskType::class => new ConflictingExternalTaskType(), |
|
| 610 | - ]; |
|
| 611 | - |
|
| 612 | - $userManager = Server::get(IUserManager::class); |
|
| 613 | - if (!$userManager->userExists(self::TEST_USER)) { |
|
| 614 | - $userManager->createUser(self::TEST_USER, 'test'); |
|
| 615 | - } |
|
| 616 | - |
|
| 617 | - $this->serverContainer = $this->createMock(IServerContainer::class); |
|
| 618 | - $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) { |
|
| 619 | - return $this->providers[$class]; |
|
| 620 | - }); |
|
| 621 | - |
|
| 622 | - $this->eventDispatcher = new EventDispatcher( |
|
| 623 | - new \Symfony\Component\EventDispatcher\EventDispatcher(), |
|
| 624 | - $this->serverContainer, |
|
| 625 | - Server::get(LoggerInterface::class), |
|
| 626 | - ); |
|
| 627 | - |
|
| 628 | - $this->registrationContext = $this->createMock(RegistrationContext::class); |
|
| 629 | - $this->coordinator = $this->createMock(Coordinator::class); |
|
| 630 | - $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext); |
|
| 631 | - |
|
| 632 | - $this->rootFolder = Server::get(IRootFolder::class); |
|
| 633 | - |
|
| 634 | - $this->taskMapper = Server::get(TaskMapper::class); |
|
| 635 | - |
|
| 636 | - $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']); |
|
| 637 | - $this->jobList->expects($this->any())->method('add')->willReturnCallback(function (): void { |
|
| 638 | - }); |
|
| 639 | - |
|
| 640 | - $this->eventDispatcher = $this->createMock(IEventDispatcher::class); |
|
| 641 | - $this->configureEventDispatcherMock(); |
|
| 642 | - |
|
| 643 | - $text2imageManager = new \OC\TextToImage\Manager( |
|
| 644 | - $this->serverContainer, |
|
| 645 | - $this->coordinator, |
|
| 646 | - Server::get(LoggerInterface::class), |
|
| 647 | - $this->jobList, |
|
| 648 | - Server::get(\OC\TextToImage\Db\TaskMapper::class), |
|
| 649 | - Server::get(IConfig::class), |
|
| 650 | - Server::get(IAppDataFactory::class), |
|
| 651 | - ); |
|
| 652 | - |
|
| 653 | - $this->userMountCache = $this->createMock(IUserMountCache::class); |
|
| 654 | - $this->config = Server::get(IConfig::class); |
|
| 655 | - $this->appConfig = Server::get(IAppConfig::class); |
|
| 656 | - $this->manager = new Manager( |
|
| 657 | - $this->appConfig, |
|
| 658 | - $this->coordinator, |
|
| 659 | - $this->serverContainer, |
|
| 660 | - Server::get(LoggerInterface::class), |
|
| 661 | - $this->taskMapper, |
|
| 662 | - $this->jobList, |
|
| 663 | - $this->eventDispatcher, |
|
| 664 | - Server::get(IAppDataFactory::class), |
|
| 665 | - Server::get(IRootFolder::class), |
|
| 666 | - $text2imageManager, |
|
| 667 | - $this->userMountCache, |
|
| 668 | - Server::get(IClientService::class), |
|
| 669 | - Server::get(IAppManager::class), |
|
| 670 | - $userManager, |
|
| 671 | - Server::get(IUserSession::class), |
|
| 672 | - Server::get(ICacheFactory::class), |
|
| 673 | - Server::get(IFactory::class), |
|
| 674 | - ); |
|
| 675 | - } |
|
| 676 | - |
|
| 677 | - private function getFile(string $name, string $content): File { |
|
| 678 | - $folder = $this->rootFolder->getUserFolder(self::TEST_USER); |
|
| 679 | - $file = $folder->newFile($name, $content); |
|
| 680 | - return $file; |
|
| 681 | - } |
|
| 682 | - |
|
| 683 | - public function testShouldNotHaveAnyProviders(): void { |
|
| 684 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 685 | - self::assertCount(0, $this->manager->getAvailableTaskTypes()); |
|
| 686 | - self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); |
|
| 687 | - self::assertFalse($this->manager->hasProviders()); |
|
| 688 | - self::expectException(PreConditionNotMetException::class); |
|
| 689 | - $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); |
|
| 690 | - } |
|
| 691 | - |
|
| 692 | - public function testProviderShouldBeRegisteredAndTaskTypeDisabled(): void { |
|
| 693 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 694 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 695 | - ]); |
|
| 696 | - $taskProcessingTypeSettings = [ |
|
| 697 | - TextToText::ID => false, |
|
| 698 | - ]; |
|
| 699 | - $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); |
|
| 700 | - self::assertCount(0, $this->manager->getAvailableTaskTypes()); |
|
| 701 | - self::assertCount(1, $this->manager->getAvailableTaskTypes(true)); |
|
| 702 | - self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); |
|
| 703 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds(true)); |
|
| 704 | - self::assertTrue($this->manager->hasProviders()); |
|
| 705 | - self::expectException(PreConditionNotMetException::class); |
|
| 706 | - $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); |
|
| 707 | - } |
|
| 708 | - |
|
| 709 | - |
|
| 710 | - public function testProviderShouldBeRegisteredAndTaskFailValidation(): void { |
|
| 711 | - $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true); |
|
| 712 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 713 | - new ServiceRegistration('test', BrokenSyncProvider::class) |
|
| 714 | - ]); |
|
| 715 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 716 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 717 | - self::assertTrue($this->manager->hasProviders()); |
|
| 718 | - $task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null); |
|
| 719 | - self::assertNull($task->getId()); |
|
| 720 | - self::expectException(ValidationException::class); |
|
| 721 | - $this->manager->scheduleTask($task); |
|
| 722 | - } |
|
| 723 | - |
|
| 724 | - public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation(): void { |
|
| 725 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 726 | - new ServiceRegistration('test', AudioToImage::class) |
|
| 727 | - ]); |
|
| 728 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 729 | - new ServiceRegistration('test', AsyncProvider::class) |
|
| 730 | - ]); |
|
| 731 | - $user = $this->createMock(IUser::class); |
|
| 732 | - $user->expects($this->any())->method('getUID')->willReturn(null); |
|
| 733 | - $mount = $this->createMock(ICachedMountInfo::class); |
|
| 734 | - $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 735 | - $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 736 | - |
|
| 737 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 738 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 739 | - self::assertTrue($this->manager->hasProviders()); |
|
| 740 | - |
|
| 741 | - $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 742 | - $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', null); |
|
| 743 | - self::assertNull($task->getId()); |
|
| 744 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 745 | - self::expectException(UnauthorizedException::class); |
|
| 746 | - $this->manager->scheduleTask($task); |
|
| 747 | - } |
|
| 748 | - |
|
| 749 | - public function testProviderShouldBeRegisteredAndFail(): void { |
|
| 750 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 751 | - new ServiceRegistration('test', FailingSyncProvider::class) |
|
| 752 | - ]); |
|
| 753 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 754 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 755 | - self::assertTrue($this->manager->hasProviders()); |
|
| 756 | - $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 757 | - self::assertNull($task->getId()); |
|
| 758 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 759 | - $this->manager->scheduleTask($task); |
|
| 760 | - self::assertNotNull($task->getId()); |
|
| 761 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 762 | - |
|
| 763 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 764 | - |
|
| 765 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 766 | - Server::get(ITimeFactory::class), |
|
| 767 | - $this->manager, |
|
| 768 | - $this->jobList, |
|
| 769 | - Server::get(LoggerInterface::class), |
|
| 770 | - ); |
|
| 771 | - $backgroundJob->start($this->jobList); |
|
| 772 | - |
|
| 773 | - $task = $this->manager->getTask($task->getId()); |
|
| 774 | - self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 775 | - self::assertEquals(FailingSyncProvider::ERROR_MESSAGE, $task->getErrorMessage()); |
|
| 776 | - } |
|
| 777 | - |
|
| 778 | - public function testProviderShouldBeRegisteredAndFailOutputValidation(): void { |
|
| 779 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 780 | - new ServiceRegistration('test', BrokenSyncProvider::class) |
|
| 781 | - ]); |
|
| 782 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 783 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 784 | - self::assertTrue($this->manager->hasProviders()); |
|
| 785 | - $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 786 | - self::assertNull($task->getId()); |
|
| 787 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 788 | - $this->manager->scheduleTask($task); |
|
| 789 | - self::assertNotNull($task->getId()); |
|
| 790 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 791 | - |
|
| 792 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 793 | - |
|
| 794 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 795 | - Server::get(ITimeFactory::class), |
|
| 796 | - $this->manager, |
|
| 797 | - $this->jobList, |
|
| 798 | - Server::get(LoggerInterface::class), |
|
| 799 | - ); |
|
| 800 | - $backgroundJob->start($this->jobList); |
|
| 801 | - |
|
| 802 | - $task = $this->manager->getTask($task->getId()); |
|
| 803 | - self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 804 | - self::assertEquals('The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec', $task->getErrorMessage()); |
|
| 805 | - } |
|
| 806 | - |
|
| 807 | - public function testProviderShouldBeRegisteredAndRun(): void { |
|
| 808 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 809 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 810 | - ]); |
|
| 811 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 812 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 813 | - $taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]]; |
|
| 814 | - self::assertTrue(isset($taskTypeStruct['inputShape']['input'])); |
|
| 815 | - self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType()); |
|
| 816 | - self::assertTrue(isset($taskTypeStruct['optionalInputShape']['optionalKey'])); |
|
| 817 | - self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalInputShape']['optionalKey']->getShapeType()); |
|
| 818 | - self::assertTrue(isset($taskTypeStruct['outputShape']['output'])); |
|
| 819 | - self::assertEquals(EShapeType::Text, $taskTypeStruct['outputShape']['output']->getShapeType()); |
|
| 820 | - self::assertTrue(isset($taskTypeStruct['optionalOutputShape']['optionalKey'])); |
|
| 821 | - self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalOutputShape']['optionalKey']->getShapeType()); |
|
| 822 | - |
|
| 823 | - self::assertTrue($this->manager->hasProviders()); |
|
| 824 | - $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 825 | - self::assertNull($task->getId()); |
|
| 826 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 827 | - $this->manager->scheduleTask($task); |
|
| 828 | - self::assertNotNull($task->getId()); |
|
| 829 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 830 | - |
|
| 831 | - // Task object retrieved from db is up-to-date |
|
| 832 | - $task2 = $this->manager->getTask($task->getId()); |
|
| 833 | - self::assertEquals($task->getId(), $task2->getId()); |
|
| 834 | - self::assertEquals(['input' => 'Hello'], $task2->getInput()); |
|
| 835 | - self::assertNull($task2->getOutput()); |
|
| 836 | - self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 837 | - |
|
| 838 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 839 | - |
|
| 840 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 841 | - Server::get(ITimeFactory::class), |
|
| 842 | - $this->manager, |
|
| 843 | - $this->jobList, |
|
| 844 | - Server::get(LoggerInterface::class), |
|
| 845 | - ); |
|
| 846 | - $backgroundJob->start($this->jobList); |
|
| 847 | - |
|
| 848 | - $task = $this->manager->getTask($task->getId()); |
|
| 849 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage()); |
|
| 850 | - self::assertEquals(['output' => 'Hello'], $task->getOutput()); |
|
| 851 | - self::assertEquals(1, $task->getProgress()); |
|
| 852 | - } |
|
| 853 | - |
|
| 854 | - public function testTaskTypeExplicitlyEnabled(): void { |
|
| 855 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 856 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 857 | - ]); |
|
| 858 | - |
|
| 859 | - $taskProcessingTypeSettings = [ |
|
| 860 | - TextToText::ID => true, |
|
| 861 | - ]; |
|
| 862 | - $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); |
|
| 863 | - |
|
| 864 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 865 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 866 | - |
|
| 867 | - self::assertTrue($this->manager->hasProviders()); |
|
| 868 | - $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 869 | - self::assertNull($task->getId()); |
|
| 870 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 871 | - $this->manager->scheduleTask($task); |
|
| 872 | - self::assertNotNull($task->getId()); |
|
| 873 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 874 | - |
|
| 875 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 876 | - |
|
| 877 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 878 | - Server::get(ITimeFactory::class), |
|
| 879 | - $this->manager, |
|
| 880 | - $this->jobList, |
|
| 881 | - Server::get(LoggerInterface::class), |
|
| 882 | - ); |
|
| 883 | - $backgroundJob->start($this->jobList); |
|
| 884 | - |
|
| 885 | - $task = $this->manager->getTask($task->getId()); |
|
| 886 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage()); |
|
| 887 | - self::assertEquals(['output' => 'Hello'], $task->getOutput()); |
|
| 888 | - self::assertEquals(1, $task->getProgress()); |
|
| 889 | - } |
|
| 890 | - |
|
| 891 | - public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFileData(): void { |
|
| 892 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 893 | - new ServiceRegistration('test', AudioToImage::class) |
|
| 894 | - ]); |
|
| 895 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 896 | - new ServiceRegistration('test', AsyncProvider::class) |
|
| 897 | - ]); |
|
| 898 | - |
|
| 899 | - $user = $this->createMock(IUser::class); |
|
| 900 | - $user->expects($this->any())->method('getUID')->willReturn('testuser'); |
|
| 901 | - $mount = $this->createMock(ICachedMountInfo::class); |
|
| 902 | - $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 903 | - $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 904 | - |
|
| 905 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 906 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 907 | - |
|
| 908 | - self::assertTrue($this->manager->hasProviders()); |
|
| 909 | - $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 910 | - $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser'); |
|
| 911 | - self::assertNull($task->getId()); |
|
| 912 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 913 | - $this->manager->scheduleTask($task); |
|
| 914 | - self::assertNotNull($task->getId()); |
|
| 915 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 916 | - |
|
| 917 | - // Task object retrieved from db is up-to-date |
|
| 918 | - $task2 = $this->manager->getTask($task->getId()); |
|
| 919 | - self::assertEquals($task->getId(), $task2->getId()); |
|
| 920 | - self::assertEquals(['audio' => $audioId], $task2->getInput()); |
|
| 921 | - self::assertNull($task2->getOutput()); |
|
| 922 | - self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 923 | - |
|
| 924 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 925 | - |
|
| 926 | - $this->manager->setTaskProgress($task2->getId(), 0.1); |
|
| 927 | - $input = $this->manager->prepareInputData($task2); |
|
| 928 | - self::assertTrue(isset($input['audio'])); |
|
| 929 | - self::assertInstanceOf(File::class, $input['audio']); |
|
| 930 | - self::assertEquals($audioId, $input['audio']->getId()); |
|
| 931 | - |
|
| 932 | - $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => 'World']); |
|
| 933 | - |
|
| 934 | - $task = $this->manager->getTask($task->getId()); |
|
| 935 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 936 | - self::assertEquals(1, $task->getProgress()); |
|
| 937 | - self::assertTrue(isset($task->getOutput()['spectrogram'])); |
|
| 938 | - $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['spectrogram'], '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 939 | - self::assertNotNull($node); |
|
| 940 | - self::assertInstanceOf(File::class, $node); |
|
| 941 | - self::assertEquals('World', $node->getContent()); |
|
| 942 | - } |
|
| 943 | - |
|
| 944 | - public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileIds(): void { |
|
| 945 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 946 | - new ServiceRegistration('test', AudioToImage::class) |
|
| 947 | - ]); |
|
| 948 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 949 | - new ServiceRegistration('test', AsyncProvider::class) |
|
| 950 | - ]); |
|
| 951 | - $user = $this->createMock(IUser::class); |
|
| 952 | - $user->expects($this->any())->method('getUID')->willReturn('testuser'); |
|
| 953 | - $mount = $this->createMock(ICachedMountInfo::class); |
|
| 954 | - $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 955 | - $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 956 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 957 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 958 | - |
|
| 959 | - self::assertTrue($this->manager->hasProviders()); |
|
| 960 | - $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 961 | - $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser'); |
|
| 962 | - self::assertNull($task->getId()); |
|
| 963 | - self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 964 | - $this->manager->scheduleTask($task); |
|
| 965 | - self::assertNotNull($task->getId()); |
|
| 966 | - self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 967 | - |
|
| 968 | - // Task object retrieved from db is up-to-date |
|
| 969 | - $task2 = $this->manager->getTask($task->getId()); |
|
| 970 | - self::assertEquals($task->getId(), $task2->getId()); |
|
| 971 | - self::assertEquals(['audio' => $audioId], $task2->getInput()); |
|
| 972 | - self::assertNull($task2->getOutput()); |
|
| 973 | - self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 974 | - |
|
| 975 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 976 | - |
|
| 977 | - $this->manager->setTaskProgress($task2->getId(), 0.1); |
|
| 978 | - $input = $this->manager->prepareInputData($task2); |
|
| 979 | - self::assertTrue(isset($input['audio'])); |
|
| 980 | - self::assertInstanceOf(File::class, $input['audio']); |
|
| 981 | - self::assertEquals($audioId, $input['audio']->getId()); |
|
| 982 | - |
|
| 983 | - $outputFileId = $this->getFile('audioOutput', 'World')->getId(); |
|
| 984 | - |
|
| 985 | - $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => $outputFileId], true); |
|
| 986 | - |
|
| 987 | - $task = $this->manager->getTask($task->getId()); |
|
| 988 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 989 | - self::assertEquals(1, $task->getProgress()); |
|
| 990 | - self::assertTrue(isset($task->getOutput()['spectrogram'])); |
|
| 991 | - $node = $this->rootFolder->getFirstNodeById($task->getOutput()['spectrogram']); |
|
| 992 | - self::assertNotNull($node, 'fileId:' . $task->getOutput()['spectrogram']); |
|
| 993 | - self::assertInstanceOf(File::class, $node); |
|
| 994 | - self::assertEquals('World', $node->getContent()); |
|
| 995 | - } |
|
| 996 | - |
|
| 997 | - public function testNonexistentTask(): void { |
|
| 998 | - $this->expectException(NotFoundException::class); |
|
| 999 | - $this->manager->getTask(2147483646); |
|
| 1000 | - } |
|
| 1001 | - |
|
| 1002 | - public function testOldTasksShouldBeCleanedUp(): void { |
|
| 1003 | - $currentTime = new \DateTime('now'); |
|
| 1004 | - $timeFactory = $this->createMock(ITimeFactory::class); |
|
| 1005 | - $timeFactory->expects($this->any())->method('getDateTime')->willReturnCallback(fn () => $currentTime); |
|
| 1006 | - $timeFactory->expects($this->any())->method('getTime')->willReturnCallback(fn () => $currentTime->getTimestamp()); |
|
| 1007 | - |
|
| 1008 | - $this->taskMapper = new TaskMapper( |
|
| 1009 | - Server::get(IDBConnection::class), |
|
| 1010 | - $timeFactory, |
|
| 1011 | - ); |
|
| 1012 | - |
|
| 1013 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1014 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1015 | - ]); |
|
| 1016 | - self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 1017 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1018 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1019 | - $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 1020 | - $this->manager->scheduleTask($task); |
|
| 1021 | - |
|
| 1022 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1023 | - |
|
| 1024 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 1025 | - Server::get(ITimeFactory::class), |
|
| 1026 | - $this->manager, |
|
| 1027 | - $this->jobList, |
|
| 1028 | - Server::get(LoggerInterface::class), |
|
| 1029 | - ); |
|
| 1030 | - $backgroundJob->start($this->jobList); |
|
| 1031 | - |
|
| 1032 | - $task = $this->manager->getTask($task->getId()); |
|
| 1033 | - |
|
| 1034 | - $currentTime = $currentTime->add(new \DateInterval('P1Y')); |
|
| 1035 | - // run background job |
|
| 1036 | - $bgJob = new RemoveOldTasksBackgroundJob( |
|
| 1037 | - $timeFactory, |
|
| 1038 | - $this->manager, |
|
| 1039 | - $this->taskMapper, |
|
| 1040 | - Server::get(LoggerInterface::class), |
|
| 1041 | - Server::get(IAppDataFactory::class), |
|
| 1042 | - ); |
|
| 1043 | - $bgJob->setArgument([]); |
|
| 1044 | - $bgJob->start($this->jobList); |
|
| 1045 | - |
|
| 1046 | - $this->expectException(NotFoundException::class); |
|
| 1047 | - $this->manager->getTask($task->getId()); |
|
| 1048 | - } |
|
| 1049 | - |
|
| 1050 | - public function testShouldTransparentlyHandleTextProcessingProviders(): void { |
|
| 1051 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
| 1052 | - new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class) |
|
| 1053 | - ]); |
|
| 1054 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1055 | - ]); |
|
| 1056 | - $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1057 | - self::assertCount(1, $taskTypes); |
|
| 1058 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1059 | - self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
| 1060 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1061 | - $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
| 1062 | - $this->manager->scheduleTask($task); |
|
| 1063 | - |
|
| 1064 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1065 | - |
|
| 1066 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 1067 | - Server::get(ITimeFactory::class), |
|
| 1068 | - $this->manager, |
|
| 1069 | - $this->jobList, |
|
| 1070 | - Server::get(LoggerInterface::class), |
|
| 1071 | - ); |
|
| 1072 | - $backgroundJob->start($this->jobList); |
|
| 1073 | - |
|
| 1074 | - $task = $this->manager->getTask($task->getId()); |
|
| 1075 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 1076 | - self::assertIsArray($task->getOutput()); |
|
| 1077 | - self::assertTrue(isset($task->getOutput()['output'])); |
|
| 1078 | - self::assertEquals('Hello Summarize', $task->getOutput()['output']); |
|
| 1079 | - self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran); |
|
| 1080 | - } |
|
| 1081 | - |
|
| 1082 | - public function testShouldTransparentlyHandleFailingTextProcessingProviders(): void { |
|
| 1083 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
| 1084 | - new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class) |
|
| 1085 | - ]); |
|
| 1086 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1087 | - ]); |
|
| 1088 | - $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1089 | - self::assertCount(1, $taskTypes); |
|
| 1090 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1091 | - self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
| 1092 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1093 | - $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
| 1094 | - $this->manager->scheduleTask($task); |
|
| 1095 | - |
|
| 1096 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 1097 | - |
|
| 1098 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 1099 | - Server::get(ITimeFactory::class), |
|
| 1100 | - $this->manager, |
|
| 1101 | - $this->jobList, |
|
| 1102 | - Server::get(LoggerInterface::class), |
|
| 1103 | - ); |
|
| 1104 | - $backgroundJob->start($this->jobList); |
|
| 1105 | - |
|
| 1106 | - $task = $this->manager->getTask($task->getId()); |
|
| 1107 | - self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 1108 | - self::assertTrue($task->getOutput() === null); |
|
| 1109 | - self::assertEquals('ERROR', $task->getErrorMessage()); |
|
| 1110 | - self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran); |
|
| 1111 | - } |
|
| 1112 | - |
|
| 1113 | - public function testShouldTransparentlyHandleText2ImageProviders(): void { |
|
| 1114 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
| 1115 | - new ServiceRegistration('test', SuccessfulTextToImageProvider::class) |
|
| 1116 | - ]); |
|
| 1117 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1118 | - ]); |
|
| 1119 | - $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1120 | - self::assertCount(1, $taskTypes); |
|
| 1121 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1122 | - self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
| 1123 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1124 | - $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
| 1125 | - $this->manager->scheduleTask($task); |
|
| 1126 | - |
|
| 1127 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1128 | - |
|
| 1129 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 1130 | - Server::get(ITimeFactory::class), |
|
| 1131 | - $this->manager, |
|
| 1132 | - $this->jobList, |
|
| 1133 | - Server::get(LoggerInterface::class), |
|
| 1134 | - ); |
|
| 1135 | - $backgroundJob->start($this->jobList); |
|
| 1136 | - |
|
| 1137 | - $task = $this->manager->getTask($task->getId()); |
|
| 1138 | - self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 1139 | - self::assertIsArray($task->getOutput()); |
|
| 1140 | - self::assertTrue(isset($task->getOutput()['images'])); |
|
| 1141 | - self::assertIsArray($task->getOutput()['images']); |
|
| 1142 | - self::assertCount(3, $task->getOutput()['images']); |
|
| 1143 | - self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran); |
|
| 1144 | - $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['images'][0], '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 1145 | - self::assertNotNull($node); |
|
| 1146 | - self::assertInstanceOf(File::class, $node); |
|
| 1147 | - self::assertEquals('test', $node->getContent()); |
|
| 1148 | - } |
|
| 1149 | - |
|
| 1150 | - public function testShouldTransparentlyHandleFailingText2ImageProviders(): void { |
|
| 1151 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
| 1152 | - new ServiceRegistration('test', FailingTextToImageProvider::class) |
|
| 1153 | - ]); |
|
| 1154 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1155 | - ]); |
|
| 1156 | - $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1157 | - self::assertCount(1, $taskTypes); |
|
| 1158 | - self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1159 | - self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
| 1160 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1161 | - $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
| 1162 | - $this->manager->scheduleTask($task); |
|
| 1163 | - |
|
| 1164 | - $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 1165 | - |
|
| 1166 | - $backgroundJob = new SynchronousBackgroundJob( |
|
| 1167 | - Server::get(ITimeFactory::class), |
|
| 1168 | - $this->manager, |
|
| 1169 | - $this->jobList, |
|
| 1170 | - Server::get(LoggerInterface::class), |
|
| 1171 | - ); |
|
| 1172 | - $backgroundJob->start($this->jobList); |
|
| 1173 | - |
|
| 1174 | - $task = $this->manager->getTask($task->getId()); |
|
| 1175 | - self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 1176 | - self::assertTrue($task->getOutput() === null); |
|
| 1177 | - self::assertEquals('ERROR', $task->getErrorMessage()); |
|
| 1178 | - self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran); |
|
| 1179 | - } |
|
| 1180 | - |
|
| 1181 | - public function testMergeProvidersLocalAndEvent() { |
|
| 1182 | - // Arrange: Local provider registered, DIFFERENT external provider via event |
|
| 1183 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1184 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1185 | - ]); |
|
| 1186 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1187 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1188 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1189 | - |
|
| 1190 | - $externalProvider = new ExternalProvider(); // ID = 'event:external:provider' |
|
| 1191 | - $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1192 | - $this->manager = $this->createManagerInstance(); |
|
| 1193 | - |
|
| 1194 | - // Act |
|
| 1195 | - $providers = $this->manager->getProviders(); |
|
| 1196 | - |
|
| 1197 | - // Assert: Both providers should be present |
|
| 1198 | - self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers); |
|
| 1199 | - self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]); |
|
| 1200 | - self::assertArrayHasKey(ExternalProvider::ID, $providers); |
|
| 1201 | - self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]); |
|
| 1202 | - self::assertCount(2, $providers); |
|
| 1203 | - } |
|
| 1204 | - |
|
| 1205 | - public function testGetProvidersIncludesExternalViaEvent() { |
|
| 1206 | - // Arrange: No local providers, one external provider via event |
|
| 1207 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1208 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1209 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1210 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1211 | - |
|
| 1212 | - |
|
| 1213 | - $externalProvider = new ExternalProvider(); |
|
| 1214 | - $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1215 | - $this->manager = $this->createManagerInstance(); // Create manager with configured mocks |
|
| 1216 | - |
|
| 1217 | - // Act |
|
| 1218 | - $providers = $this->manager->getProviders(); // Returns ID-indexed array |
|
| 1219 | - |
|
| 1220 | - // Assert |
|
| 1221 | - self::assertArrayHasKey(ExternalProvider::ID, $providers); |
|
| 1222 | - self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]); |
|
| 1223 | - self::assertCount(1, $providers); |
|
| 1224 | - self::assertTrue($this->manager->hasProviders()); |
|
| 1225 | - } |
|
| 1226 | - |
|
| 1227 | - public function testGetAvailableTaskTypesIncludesExternalViaEvent() { |
|
| 1228 | - // Arrange: No local types/providers, one external type and provider via event |
|
| 1229 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1230 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([]); |
|
| 1231 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1232 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1233 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1234 | - |
|
| 1235 | - $externalProvider = new ExternalProvider(); // Provides ExternalTaskType |
|
| 1236 | - $externalTaskType = new ExternalTaskType(); |
|
| 1237 | - $this->configureEventDispatcherMock( |
|
| 1238 | - providersToAdd: [$externalProvider], |
|
| 1239 | - taskTypesToAdd: [$externalTaskType] |
|
| 1240 | - ); |
|
| 1241 | - $this->manager = $this->createManagerInstance(); |
|
| 1242 | - |
|
| 1243 | - // Act |
|
| 1244 | - $availableTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1245 | - |
|
| 1246 | - // Assert |
|
| 1247 | - self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); |
|
| 1248 | - self::assertContains(ExternalTaskType::ID, $this->manager->getAvailableTaskTypeIds()); |
|
| 1249 | - self::assertEquals(ExternalTaskType::ID, $externalProvider->getTaskTypeId(), 'Test Sanity: Provider must handle the Task Type'); |
|
| 1250 | - self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']); |
|
| 1251 | - // Check if shapes match the external type/provider |
|
| 1252 | - self::assertArrayHasKey('external_input', $availableTypes[ExternalTaskType::ID]['inputShape']); |
|
| 1253 | - self::assertArrayHasKey('external_output', $availableTypes[ExternalTaskType::ID]['outputShape']); |
|
| 1254 | - self::assertEmpty($availableTypes[ExternalTaskType::ID]['optionalInputShape']); // From ExternalProvider |
|
| 1255 | - } |
|
| 1256 | - |
|
| 1257 | - public function testLocalProviderWinsConflictWithEvent() { |
|
| 1258 | - // Arrange: Local provider registered, conflicting external provider via event |
|
| 1259 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1260 | - new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1261 | - ]); |
|
| 1262 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1263 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1264 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1265 | - |
|
| 1266 | - $conflictingExternalProvider = new ConflictingExternalProvider(); // ID = 'test:sync:success' |
|
| 1267 | - $this->configureEventDispatcherMock(providersToAdd: [$conflictingExternalProvider]); |
|
| 1268 | - $this->manager = $this->createManagerInstance(); |
|
| 1269 | - |
|
| 1270 | - // Act |
|
| 1271 | - $providers = $this->manager->getProviders(); |
|
| 1272 | - |
|
| 1273 | - // Assert: Only the local provider should be present for the conflicting ID |
|
| 1274 | - self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers); |
|
| 1275 | - self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]); |
|
| 1276 | - self::assertCount(1, $providers); // Ensure no extra provider was added |
|
| 1277 | - } |
|
| 1278 | - |
|
| 1279 | - public function testTriggerableProviderWithNoOtherRunningTasks() { |
|
| 1280 | - // Arrange: Local provider registered, conflicting external provider via event |
|
| 1281 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1282 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1283 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1284 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1285 | - |
|
| 1286 | - $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']); |
|
| 1287 | - $externalProvider->expects($this->once())->method('trigger'); |
|
| 1288 | - $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1289 | - $this->manager = $this->createManagerInstance(); |
|
| 1290 | - |
|
| 1291 | - // Act |
|
| 1292 | - $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1293 | - $this->manager->scheduleTask($task); |
|
| 1294 | - } |
|
| 1295 | - |
|
| 1296 | - public function testTriggerableProviderWithOtherRunningTasks() { |
|
| 1297 | - // Arrange: Local provider registered, conflicting external provider via event |
|
| 1298 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1299 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1300 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1301 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1302 | - |
|
| 1303 | - $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']); |
|
| 1304 | - $externalProvider->expects($this->once())->method('trigger'); |
|
| 1305 | - $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1306 | - $this->manager = $this->createManagerInstance(); |
|
| 1307 | - |
|
| 1308 | - $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1309 | - $this->manager->scheduleTask($task); |
|
| 1310 | - $this->manager->lockTask($task); |
|
| 1311 | - |
|
| 1312 | - // Act |
|
| 1313 | - $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1314 | - $this->manager->scheduleTask($task); |
|
| 1315 | - } |
|
| 1316 | - |
|
| 1317 | - public function testMergeTaskTypesLocalAndEvent() { |
|
| 1318 | - // Arrange: Local type registered, DIFFERENT external type via event |
|
| 1319 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1320 | - new ServiceRegistration('test', AsyncProvider::class) |
|
| 1321 | - ]); |
|
| 1322 | - $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 1323 | - new ServiceRegistration('test', AudioToImage::class) |
|
| 1324 | - ]); |
|
| 1325 | - $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1326 | - $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1327 | - $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1328 | - |
|
| 1329 | - $externalTaskType = new ExternalTaskType(); // ID = 'event:external:tasktype' |
|
| 1330 | - $externalProvider = new ExternalProvider(); // Handles 'event:external:tasktype' |
|
| 1331 | - $this->configureEventDispatcherMock( |
|
| 1332 | - providersToAdd: [$externalProvider], |
|
| 1333 | - taskTypesToAdd: [$externalTaskType] |
|
| 1334 | - ); |
|
| 1335 | - $this->manager = $this->createManagerInstance(); |
|
| 1336 | - |
|
| 1337 | - // Act |
|
| 1338 | - $availableTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1339 | - $availableTypeIds = $this->manager->getAvailableTaskTypeIds(); |
|
| 1340 | - |
|
| 1341 | - // Assert: Both task types should be available |
|
| 1342 | - self::assertContains(AudioToImage::ID, $availableTypeIds); |
|
| 1343 | - self::assertArrayHasKey(AudioToImage::ID, $availableTypes); |
|
| 1344 | - self::assertEquals(AudioToImage::class, $availableTypes[AudioToImage::ID]['name']); |
|
| 1345 | - |
|
| 1346 | - self::assertContains(ExternalTaskType::ID, $availableTypeIds); |
|
| 1347 | - self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); |
|
| 1348 | - self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']); |
|
| 1349 | - |
|
| 1350 | - self::assertCount(2, $availableTypes); |
|
| 1351 | - } |
|
| 1352 | - |
|
| 1353 | - private function createManagerInstance(): Manager { |
|
| 1354 | - // Clear potentially cached config values if needed |
|
| 1355 | - $this->appConfig->deleteKey('core', 'ai.taskprocessing_type_preferences'); |
|
| 1356 | - |
|
| 1357 | - // Re-create Text2ImageManager if its state matters or mocks change |
|
| 1358 | - $text2imageManager = new \OC\TextToImage\Manager( |
|
| 1359 | - $this->serverContainer, |
|
| 1360 | - $this->coordinator, |
|
| 1361 | - Server::get(LoggerInterface::class), |
|
| 1362 | - $this->jobList, |
|
| 1363 | - Server::get(\OC\TextToImage\Db\TaskMapper::class), |
|
| 1364 | - $this->config, // Use the shared config mock |
|
| 1365 | - Server::get(IAppDataFactory::class), |
|
| 1366 | - ); |
|
| 1367 | - |
|
| 1368 | - return new Manager( |
|
| 1369 | - $this->appConfig, |
|
| 1370 | - $this->coordinator, |
|
| 1371 | - $this->serverContainer, |
|
| 1372 | - Server::get(LoggerInterface::class), |
|
| 1373 | - $this->taskMapper, |
|
| 1374 | - $this->jobList, |
|
| 1375 | - $this->eventDispatcher, // Use the potentially reconfigured mock |
|
| 1376 | - Server::get(IAppDataFactory::class), |
|
| 1377 | - $this->rootFolder, |
|
| 1378 | - $text2imageManager, |
|
| 1379 | - $this->userMountCache, |
|
| 1380 | - Server::get(IClientService::class), |
|
| 1381 | - Server::get(IAppManager::class), |
|
| 1382 | - Server::get(IUserManager::class), |
|
| 1383 | - Server::get(IUserSession::class), |
|
| 1384 | - Server::get(ICacheFactory::class), |
|
| 1385 | - Server::get(IFactory::class), |
|
| 1386 | - ); |
|
| 1387 | - } |
|
| 1388 | - |
|
| 1389 | - private function configureEventDispatcherMock( |
|
| 1390 | - array $providersToAdd = [], |
|
| 1391 | - array $taskTypesToAdd = [], |
|
| 1392 | - ?int $expectedCalls = null, |
|
| 1393 | - ): void { |
|
| 1394 | - $dispatchExpectation = $expectedCalls === null ? $this->any() : $this->exactly($expectedCalls); |
|
| 1395 | - |
|
| 1396 | - $this->eventDispatcher->expects($dispatchExpectation) |
|
| 1397 | - ->method('dispatchTyped') |
|
| 1398 | - ->willReturnCallback(function (object $event) use ($providersToAdd, $taskTypesToAdd): void { |
|
| 1399 | - if ($event instanceof GetTaskProcessingProvidersEvent) { |
|
| 1400 | - foreach ($providersToAdd as $providerInstance) { |
|
| 1401 | - $event->addProvider($providerInstance); |
|
| 1402 | - } |
|
| 1403 | - foreach ($taskTypesToAdd as $taskTypeInstance) { |
|
| 1404 | - $event->addTaskType($taskTypeInstance); |
|
| 1405 | - } |
|
| 1406 | - } |
|
| 1407 | - }); |
|
| 1408 | - } |
|
| 577 | + private IManager $manager; |
|
| 578 | + private Coordinator $coordinator; |
|
| 579 | + private array $providers; |
|
| 580 | + private IServerContainer $serverContainer; |
|
| 581 | + private IEventDispatcher $eventDispatcher; |
|
| 582 | + private RegistrationContext $registrationContext; |
|
| 583 | + private TaskMapper $taskMapper; |
|
| 584 | + private IJobList $jobList; |
|
| 585 | + private IUserMountCache $userMountCache; |
|
| 586 | + private IRootFolder $rootFolder; |
|
| 587 | + private IConfig $config; |
|
| 588 | + private IAppConfig $appConfig; |
|
| 589 | + |
|
| 590 | + public const TEST_USER = 'testuser'; |
|
| 591 | + |
|
| 592 | + protected function setUp(): void { |
|
| 593 | + parent::setUp(); |
|
| 594 | + |
|
| 595 | + $this->providers = [ |
|
| 596 | + SuccessfulSyncProvider::class => new SuccessfulSyncProvider(), |
|
| 597 | + FailingSyncProvider::class => new FailingSyncProvider(), |
|
| 598 | + BrokenSyncProvider::class => new BrokenSyncProvider(), |
|
| 599 | + AsyncProvider::class => new AsyncProvider(), |
|
| 600 | + AudioToImage::class => new AudioToImage(), |
|
| 601 | + SuccessfulTextProcessingSummaryProvider::class => new SuccessfulTextProcessingSummaryProvider(), |
|
| 602 | + FailingTextProcessingSummaryProvider::class => new FailingTextProcessingSummaryProvider(), |
|
| 603 | + SuccessfulTextToImageProvider::class => new SuccessfulTextToImageProvider(), |
|
| 604 | + FailingTextToImageProvider::class => new FailingTextToImageProvider(), |
|
| 605 | + ExternalProvider::class => new ExternalProvider(), |
|
| 606 | + ExternalTriggerableProvider::class => new ExternalTriggerableProvider(), |
|
| 607 | + ConflictingExternalProvider::class => new ConflictingExternalProvider(), |
|
| 608 | + ExternalTaskType::class => new ExternalTaskType(), |
|
| 609 | + ConflictingExternalTaskType::class => new ConflictingExternalTaskType(), |
|
| 610 | + ]; |
|
| 611 | + |
|
| 612 | + $userManager = Server::get(IUserManager::class); |
|
| 613 | + if (!$userManager->userExists(self::TEST_USER)) { |
|
| 614 | + $userManager->createUser(self::TEST_USER, 'test'); |
|
| 615 | + } |
|
| 616 | + |
|
| 617 | + $this->serverContainer = $this->createMock(IServerContainer::class); |
|
| 618 | + $this->serverContainer->expects($this->any())->method('get')->willReturnCallback(function ($class) { |
|
| 619 | + return $this->providers[$class]; |
|
| 620 | + }); |
|
| 621 | + |
|
| 622 | + $this->eventDispatcher = new EventDispatcher( |
|
| 623 | + new \Symfony\Component\EventDispatcher\EventDispatcher(), |
|
| 624 | + $this->serverContainer, |
|
| 625 | + Server::get(LoggerInterface::class), |
|
| 626 | + ); |
|
| 627 | + |
|
| 628 | + $this->registrationContext = $this->createMock(RegistrationContext::class); |
|
| 629 | + $this->coordinator = $this->createMock(Coordinator::class); |
|
| 630 | + $this->coordinator->expects($this->any())->method('getRegistrationContext')->willReturn($this->registrationContext); |
|
| 631 | + |
|
| 632 | + $this->rootFolder = Server::get(IRootFolder::class); |
|
| 633 | + |
|
| 634 | + $this->taskMapper = Server::get(TaskMapper::class); |
|
| 635 | + |
|
| 636 | + $this->jobList = $this->createPartialMock(DummyJobList::class, ['add']); |
|
| 637 | + $this->jobList->expects($this->any())->method('add')->willReturnCallback(function (): void { |
|
| 638 | + }); |
|
| 639 | + |
|
| 640 | + $this->eventDispatcher = $this->createMock(IEventDispatcher::class); |
|
| 641 | + $this->configureEventDispatcherMock(); |
|
| 642 | + |
|
| 643 | + $text2imageManager = new \OC\TextToImage\Manager( |
|
| 644 | + $this->serverContainer, |
|
| 645 | + $this->coordinator, |
|
| 646 | + Server::get(LoggerInterface::class), |
|
| 647 | + $this->jobList, |
|
| 648 | + Server::get(\OC\TextToImage\Db\TaskMapper::class), |
|
| 649 | + Server::get(IConfig::class), |
|
| 650 | + Server::get(IAppDataFactory::class), |
|
| 651 | + ); |
|
| 652 | + |
|
| 653 | + $this->userMountCache = $this->createMock(IUserMountCache::class); |
|
| 654 | + $this->config = Server::get(IConfig::class); |
|
| 655 | + $this->appConfig = Server::get(IAppConfig::class); |
|
| 656 | + $this->manager = new Manager( |
|
| 657 | + $this->appConfig, |
|
| 658 | + $this->coordinator, |
|
| 659 | + $this->serverContainer, |
|
| 660 | + Server::get(LoggerInterface::class), |
|
| 661 | + $this->taskMapper, |
|
| 662 | + $this->jobList, |
|
| 663 | + $this->eventDispatcher, |
|
| 664 | + Server::get(IAppDataFactory::class), |
|
| 665 | + Server::get(IRootFolder::class), |
|
| 666 | + $text2imageManager, |
|
| 667 | + $this->userMountCache, |
|
| 668 | + Server::get(IClientService::class), |
|
| 669 | + Server::get(IAppManager::class), |
|
| 670 | + $userManager, |
|
| 671 | + Server::get(IUserSession::class), |
|
| 672 | + Server::get(ICacheFactory::class), |
|
| 673 | + Server::get(IFactory::class), |
|
| 674 | + ); |
|
| 675 | + } |
|
| 676 | + |
|
| 677 | + private function getFile(string $name, string $content): File { |
|
| 678 | + $folder = $this->rootFolder->getUserFolder(self::TEST_USER); |
|
| 679 | + $file = $folder->newFile($name, $content); |
|
| 680 | + return $file; |
|
| 681 | + } |
|
| 682 | + |
|
| 683 | + public function testShouldNotHaveAnyProviders(): void { |
|
| 684 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 685 | + self::assertCount(0, $this->manager->getAvailableTaskTypes()); |
|
| 686 | + self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); |
|
| 687 | + self::assertFalse($this->manager->hasProviders()); |
|
| 688 | + self::expectException(PreConditionNotMetException::class); |
|
| 689 | + $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); |
|
| 690 | + } |
|
| 691 | + |
|
| 692 | + public function testProviderShouldBeRegisteredAndTaskTypeDisabled(): void { |
|
| 693 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 694 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 695 | + ]); |
|
| 696 | + $taskProcessingTypeSettings = [ |
|
| 697 | + TextToText::ID => false, |
|
| 698 | + ]; |
|
| 699 | + $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); |
|
| 700 | + self::assertCount(0, $this->manager->getAvailableTaskTypes()); |
|
| 701 | + self::assertCount(1, $this->manager->getAvailableTaskTypes(true)); |
|
| 702 | + self::assertCount(0, $this->manager->getAvailableTaskTypeIds()); |
|
| 703 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds(true)); |
|
| 704 | + self::assertTrue($this->manager->hasProviders()); |
|
| 705 | + self::expectException(PreConditionNotMetException::class); |
|
| 706 | + $this->manager->scheduleTask(new Task(TextToText::ID, ['input' => 'Hello'], 'test', null)); |
|
| 707 | + } |
|
| 708 | + |
|
| 709 | + |
|
| 710 | + public function testProviderShouldBeRegisteredAndTaskFailValidation(): void { |
|
| 711 | + $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', '', lazy: true); |
|
| 712 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 713 | + new ServiceRegistration('test', BrokenSyncProvider::class) |
|
| 714 | + ]); |
|
| 715 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 716 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 717 | + self::assertTrue($this->manager->hasProviders()); |
|
| 718 | + $task = new Task(TextToText::ID, ['wrongInputKey' => 'Hello'], 'test', null); |
|
| 719 | + self::assertNull($task->getId()); |
|
| 720 | + self::expectException(ValidationException::class); |
|
| 721 | + $this->manager->scheduleTask($task); |
|
| 722 | + } |
|
| 723 | + |
|
| 724 | + public function testProviderShouldBeRegisteredAndTaskWithFilesFailValidation(): void { |
|
| 725 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 726 | + new ServiceRegistration('test', AudioToImage::class) |
|
| 727 | + ]); |
|
| 728 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 729 | + new ServiceRegistration('test', AsyncProvider::class) |
|
| 730 | + ]); |
|
| 731 | + $user = $this->createMock(IUser::class); |
|
| 732 | + $user->expects($this->any())->method('getUID')->willReturn(null); |
|
| 733 | + $mount = $this->createMock(ICachedMountInfo::class); |
|
| 734 | + $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 735 | + $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 736 | + |
|
| 737 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 738 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 739 | + self::assertTrue($this->manager->hasProviders()); |
|
| 740 | + |
|
| 741 | + $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 742 | + $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', null); |
|
| 743 | + self::assertNull($task->getId()); |
|
| 744 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 745 | + self::expectException(UnauthorizedException::class); |
|
| 746 | + $this->manager->scheduleTask($task); |
|
| 747 | + } |
|
| 748 | + |
|
| 749 | + public function testProviderShouldBeRegisteredAndFail(): void { |
|
| 750 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 751 | + new ServiceRegistration('test', FailingSyncProvider::class) |
|
| 752 | + ]); |
|
| 753 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 754 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 755 | + self::assertTrue($this->manager->hasProviders()); |
|
| 756 | + $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 757 | + self::assertNull($task->getId()); |
|
| 758 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 759 | + $this->manager->scheduleTask($task); |
|
| 760 | + self::assertNotNull($task->getId()); |
|
| 761 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 762 | + |
|
| 763 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 764 | + |
|
| 765 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 766 | + Server::get(ITimeFactory::class), |
|
| 767 | + $this->manager, |
|
| 768 | + $this->jobList, |
|
| 769 | + Server::get(LoggerInterface::class), |
|
| 770 | + ); |
|
| 771 | + $backgroundJob->start($this->jobList); |
|
| 772 | + |
|
| 773 | + $task = $this->manager->getTask($task->getId()); |
|
| 774 | + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 775 | + self::assertEquals(FailingSyncProvider::ERROR_MESSAGE, $task->getErrorMessage()); |
|
| 776 | + } |
|
| 777 | + |
|
| 778 | + public function testProviderShouldBeRegisteredAndFailOutputValidation(): void { |
|
| 779 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 780 | + new ServiceRegistration('test', BrokenSyncProvider::class) |
|
| 781 | + ]); |
|
| 782 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 783 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 784 | + self::assertTrue($this->manager->hasProviders()); |
|
| 785 | + $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 786 | + self::assertNull($task->getId()); |
|
| 787 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 788 | + $this->manager->scheduleTask($task); |
|
| 789 | + self::assertNotNull($task->getId()); |
|
| 790 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 791 | + |
|
| 792 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 793 | + |
|
| 794 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 795 | + Server::get(ITimeFactory::class), |
|
| 796 | + $this->manager, |
|
| 797 | + $this->jobList, |
|
| 798 | + Server::get(LoggerInterface::class), |
|
| 799 | + ); |
|
| 800 | + $backgroundJob->start($this->jobList); |
|
| 801 | + |
|
| 802 | + $task = $this->manager->getTask($task->getId()); |
|
| 803 | + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 804 | + self::assertEquals('The task was processed successfully but the provider\'s output doesn\'t pass validation against the task type\'s outputShape spec and/or the provider\'s own optionalOutputShape spec', $task->getErrorMessage()); |
|
| 805 | + } |
|
| 806 | + |
|
| 807 | + public function testProviderShouldBeRegisteredAndRun(): void { |
|
| 808 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 809 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 810 | + ]); |
|
| 811 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 812 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 813 | + $taskTypeStruct = $this->manager->getAvailableTaskTypes()[array_keys($this->manager->getAvailableTaskTypes())[0]]; |
|
| 814 | + self::assertTrue(isset($taskTypeStruct['inputShape']['input'])); |
|
| 815 | + self::assertEquals(EShapeType::Text, $taskTypeStruct['inputShape']['input']->getShapeType()); |
|
| 816 | + self::assertTrue(isset($taskTypeStruct['optionalInputShape']['optionalKey'])); |
|
| 817 | + self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalInputShape']['optionalKey']->getShapeType()); |
|
| 818 | + self::assertTrue(isset($taskTypeStruct['outputShape']['output'])); |
|
| 819 | + self::assertEquals(EShapeType::Text, $taskTypeStruct['outputShape']['output']->getShapeType()); |
|
| 820 | + self::assertTrue(isset($taskTypeStruct['optionalOutputShape']['optionalKey'])); |
|
| 821 | + self::assertEquals(EShapeType::Text, $taskTypeStruct['optionalOutputShape']['optionalKey']->getShapeType()); |
|
| 822 | + |
|
| 823 | + self::assertTrue($this->manager->hasProviders()); |
|
| 824 | + $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 825 | + self::assertNull($task->getId()); |
|
| 826 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 827 | + $this->manager->scheduleTask($task); |
|
| 828 | + self::assertNotNull($task->getId()); |
|
| 829 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 830 | + |
|
| 831 | + // Task object retrieved from db is up-to-date |
|
| 832 | + $task2 = $this->manager->getTask($task->getId()); |
|
| 833 | + self::assertEquals($task->getId(), $task2->getId()); |
|
| 834 | + self::assertEquals(['input' => 'Hello'], $task2->getInput()); |
|
| 835 | + self::assertNull($task2->getOutput()); |
|
| 836 | + self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 837 | + |
|
| 838 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 839 | + |
|
| 840 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 841 | + Server::get(ITimeFactory::class), |
|
| 842 | + $this->manager, |
|
| 843 | + $this->jobList, |
|
| 844 | + Server::get(LoggerInterface::class), |
|
| 845 | + ); |
|
| 846 | + $backgroundJob->start($this->jobList); |
|
| 847 | + |
|
| 848 | + $task = $this->manager->getTask($task->getId()); |
|
| 849 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage()); |
|
| 850 | + self::assertEquals(['output' => 'Hello'], $task->getOutput()); |
|
| 851 | + self::assertEquals(1, $task->getProgress()); |
|
| 852 | + } |
|
| 853 | + |
|
| 854 | + public function testTaskTypeExplicitlyEnabled(): void { |
|
| 855 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 856 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 857 | + ]); |
|
| 858 | + |
|
| 859 | + $taskProcessingTypeSettings = [ |
|
| 860 | + TextToText::ID => true, |
|
| 861 | + ]; |
|
| 862 | + $this->appConfig->setValueString('core', 'ai.taskprocessing_type_preferences', json_encode($taskProcessingTypeSettings), lazy: true); |
|
| 863 | + |
|
| 864 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 865 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 866 | + |
|
| 867 | + self::assertTrue($this->manager->hasProviders()); |
|
| 868 | + $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 869 | + self::assertNull($task->getId()); |
|
| 870 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 871 | + $this->manager->scheduleTask($task); |
|
| 872 | + self::assertNotNull($task->getId()); |
|
| 873 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 874 | + |
|
| 875 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 876 | + |
|
| 877 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 878 | + Server::get(ITimeFactory::class), |
|
| 879 | + $this->manager, |
|
| 880 | + $this->jobList, |
|
| 881 | + Server::get(LoggerInterface::class), |
|
| 882 | + ); |
|
| 883 | + $backgroundJob->start($this->jobList); |
|
| 884 | + |
|
| 885 | + $task = $this->manager->getTask($task->getId()); |
|
| 886 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus(), 'Status is ' . $task->getStatus() . ' with error message: ' . $task->getErrorMessage()); |
|
| 887 | + self::assertEquals(['output' => 'Hello'], $task->getOutput()); |
|
| 888 | + self::assertEquals(1, $task->getProgress()); |
|
| 889 | + } |
|
| 890 | + |
|
| 891 | + public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningRawFileData(): void { |
|
| 892 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 893 | + new ServiceRegistration('test', AudioToImage::class) |
|
| 894 | + ]); |
|
| 895 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 896 | + new ServiceRegistration('test', AsyncProvider::class) |
|
| 897 | + ]); |
|
| 898 | + |
|
| 899 | + $user = $this->createMock(IUser::class); |
|
| 900 | + $user->expects($this->any())->method('getUID')->willReturn('testuser'); |
|
| 901 | + $mount = $this->createMock(ICachedMountInfo::class); |
|
| 902 | + $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 903 | + $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 904 | + |
|
| 905 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 906 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 907 | + |
|
| 908 | + self::assertTrue($this->manager->hasProviders()); |
|
| 909 | + $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 910 | + $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser'); |
|
| 911 | + self::assertNull($task->getId()); |
|
| 912 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 913 | + $this->manager->scheduleTask($task); |
|
| 914 | + self::assertNotNull($task->getId()); |
|
| 915 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 916 | + |
|
| 917 | + // Task object retrieved from db is up-to-date |
|
| 918 | + $task2 = $this->manager->getTask($task->getId()); |
|
| 919 | + self::assertEquals($task->getId(), $task2->getId()); |
|
| 920 | + self::assertEquals(['audio' => $audioId], $task2->getInput()); |
|
| 921 | + self::assertNull($task2->getOutput()); |
|
| 922 | + self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 923 | + |
|
| 924 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 925 | + |
|
| 926 | + $this->manager->setTaskProgress($task2->getId(), 0.1); |
|
| 927 | + $input = $this->manager->prepareInputData($task2); |
|
| 928 | + self::assertTrue(isset($input['audio'])); |
|
| 929 | + self::assertInstanceOf(File::class, $input['audio']); |
|
| 930 | + self::assertEquals($audioId, $input['audio']->getId()); |
|
| 931 | + |
|
| 932 | + $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => 'World']); |
|
| 933 | + |
|
| 934 | + $task = $this->manager->getTask($task->getId()); |
|
| 935 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 936 | + self::assertEquals(1, $task->getProgress()); |
|
| 937 | + self::assertTrue(isset($task->getOutput()['spectrogram'])); |
|
| 938 | + $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['spectrogram'], '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 939 | + self::assertNotNull($node); |
|
| 940 | + self::assertInstanceOf(File::class, $node); |
|
| 941 | + self::assertEquals('World', $node->getContent()); |
|
| 942 | + } |
|
| 943 | + |
|
| 944 | + public function testAsyncProviderWithFilesShouldBeRegisteredAndRunReturningFileIds(): void { |
|
| 945 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 946 | + new ServiceRegistration('test', AudioToImage::class) |
|
| 947 | + ]); |
|
| 948 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 949 | + new ServiceRegistration('test', AsyncProvider::class) |
|
| 950 | + ]); |
|
| 951 | + $user = $this->createMock(IUser::class); |
|
| 952 | + $user->expects($this->any())->method('getUID')->willReturn('testuser'); |
|
| 953 | + $mount = $this->createMock(ICachedMountInfo::class); |
|
| 954 | + $mount->expects($this->any())->method('getUser')->willReturn($user); |
|
| 955 | + $this->userMountCache->expects($this->any())->method('getMountsForFileId')->willReturn([$mount]); |
|
| 956 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 957 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 958 | + |
|
| 959 | + self::assertTrue($this->manager->hasProviders()); |
|
| 960 | + $audioId = $this->getFile('audioInput', 'Hello')->getId(); |
|
| 961 | + $task = new Task(AudioToImage::ID, ['audio' => $audioId], 'test', 'testuser'); |
|
| 962 | + self::assertNull($task->getId()); |
|
| 963 | + self::assertEquals(Task::STATUS_UNKNOWN, $task->getStatus()); |
|
| 964 | + $this->manager->scheduleTask($task); |
|
| 965 | + self::assertNotNull($task->getId()); |
|
| 966 | + self::assertEquals(Task::STATUS_SCHEDULED, $task->getStatus()); |
|
| 967 | + |
|
| 968 | + // Task object retrieved from db is up-to-date |
|
| 969 | + $task2 = $this->manager->getTask($task->getId()); |
|
| 970 | + self::assertEquals($task->getId(), $task2->getId()); |
|
| 971 | + self::assertEquals(['audio' => $audioId], $task2->getInput()); |
|
| 972 | + self::assertNull($task2->getOutput()); |
|
| 973 | + self::assertEquals(Task::STATUS_SCHEDULED, $task2->getStatus()); |
|
| 974 | + |
|
| 975 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 976 | + |
|
| 977 | + $this->manager->setTaskProgress($task2->getId(), 0.1); |
|
| 978 | + $input = $this->manager->prepareInputData($task2); |
|
| 979 | + self::assertTrue(isset($input['audio'])); |
|
| 980 | + self::assertInstanceOf(File::class, $input['audio']); |
|
| 981 | + self::assertEquals($audioId, $input['audio']->getId()); |
|
| 982 | + |
|
| 983 | + $outputFileId = $this->getFile('audioOutput', 'World')->getId(); |
|
| 984 | + |
|
| 985 | + $this->manager->setTaskResult($task2->getId(), null, ['spectrogram' => $outputFileId], true); |
|
| 986 | + |
|
| 987 | + $task = $this->manager->getTask($task->getId()); |
|
| 988 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 989 | + self::assertEquals(1, $task->getProgress()); |
|
| 990 | + self::assertTrue(isset($task->getOutput()['spectrogram'])); |
|
| 991 | + $node = $this->rootFolder->getFirstNodeById($task->getOutput()['spectrogram']); |
|
| 992 | + self::assertNotNull($node, 'fileId:' . $task->getOutput()['spectrogram']); |
|
| 993 | + self::assertInstanceOf(File::class, $node); |
|
| 994 | + self::assertEquals('World', $node->getContent()); |
|
| 995 | + } |
|
| 996 | + |
|
| 997 | + public function testNonexistentTask(): void { |
|
| 998 | + $this->expectException(NotFoundException::class); |
|
| 999 | + $this->manager->getTask(2147483646); |
|
| 1000 | + } |
|
| 1001 | + |
|
| 1002 | + public function testOldTasksShouldBeCleanedUp(): void { |
|
| 1003 | + $currentTime = new \DateTime('now'); |
|
| 1004 | + $timeFactory = $this->createMock(ITimeFactory::class); |
|
| 1005 | + $timeFactory->expects($this->any())->method('getDateTime')->willReturnCallback(fn () => $currentTime); |
|
| 1006 | + $timeFactory->expects($this->any())->method('getTime')->willReturnCallback(fn () => $currentTime->getTimestamp()); |
|
| 1007 | + |
|
| 1008 | + $this->taskMapper = new TaskMapper( |
|
| 1009 | + Server::get(IDBConnection::class), |
|
| 1010 | + $timeFactory, |
|
| 1011 | + ); |
|
| 1012 | + |
|
| 1013 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1014 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1015 | + ]); |
|
| 1016 | + self::assertCount(1, $this->manager->getAvailableTaskTypes()); |
|
| 1017 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1018 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1019 | + $task = new Task(TextToText::ID, ['input' => 'Hello'], 'test', null); |
|
| 1020 | + $this->manager->scheduleTask($task); |
|
| 1021 | + |
|
| 1022 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1023 | + |
|
| 1024 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 1025 | + Server::get(ITimeFactory::class), |
|
| 1026 | + $this->manager, |
|
| 1027 | + $this->jobList, |
|
| 1028 | + Server::get(LoggerInterface::class), |
|
| 1029 | + ); |
|
| 1030 | + $backgroundJob->start($this->jobList); |
|
| 1031 | + |
|
| 1032 | + $task = $this->manager->getTask($task->getId()); |
|
| 1033 | + |
|
| 1034 | + $currentTime = $currentTime->add(new \DateInterval('P1Y')); |
|
| 1035 | + // run background job |
|
| 1036 | + $bgJob = new RemoveOldTasksBackgroundJob( |
|
| 1037 | + $timeFactory, |
|
| 1038 | + $this->manager, |
|
| 1039 | + $this->taskMapper, |
|
| 1040 | + Server::get(LoggerInterface::class), |
|
| 1041 | + Server::get(IAppDataFactory::class), |
|
| 1042 | + ); |
|
| 1043 | + $bgJob->setArgument([]); |
|
| 1044 | + $bgJob->start($this->jobList); |
|
| 1045 | + |
|
| 1046 | + $this->expectException(NotFoundException::class); |
|
| 1047 | + $this->manager->getTask($task->getId()); |
|
| 1048 | + } |
|
| 1049 | + |
|
| 1050 | + public function testShouldTransparentlyHandleTextProcessingProviders(): void { |
|
| 1051 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
| 1052 | + new ServiceRegistration('test', SuccessfulTextProcessingSummaryProvider::class) |
|
| 1053 | + ]); |
|
| 1054 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1055 | + ]); |
|
| 1056 | + $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1057 | + self::assertCount(1, $taskTypes); |
|
| 1058 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1059 | + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
| 1060 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1061 | + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
| 1062 | + $this->manager->scheduleTask($task); |
|
| 1063 | + |
|
| 1064 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1065 | + |
|
| 1066 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 1067 | + Server::get(ITimeFactory::class), |
|
| 1068 | + $this->manager, |
|
| 1069 | + $this->jobList, |
|
| 1070 | + Server::get(LoggerInterface::class), |
|
| 1071 | + ); |
|
| 1072 | + $backgroundJob->start($this->jobList); |
|
| 1073 | + |
|
| 1074 | + $task = $this->manager->getTask($task->getId()); |
|
| 1075 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 1076 | + self::assertIsArray($task->getOutput()); |
|
| 1077 | + self::assertTrue(isset($task->getOutput()['output'])); |
|
| 1078 | + self::assertEquals('Hello Summarize', $task->getOutput()['output']); |
|
| 1079 | + self::assertTrue($this->providers[SuccessfulTextProcessingSummaryProvider::class]->ran); |
|
| 1080 | + } |
|
| 1081 | + |
|
| 1082 | + public function testShouldTransparentlyHandleFailingTextProcessingProviders(): void { |
|
| 1083 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([ |
|
| 1084 | + new ServiceRegistration('test', FailingTextProcessingSummaryProvider::class) |
|
| 1085 | + ]); |
|
| 1086 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1087 | + ]); |
|
| 1088 | + $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1089 | + self::assertCount(1, $taskTypes); |
|
| 1090 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1091 | + self::assertTrue(isset($taskTypes[TextToTextSummary::ID])); |
|
| 1092 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1093 | + $task = new Task(TextToTextSummary::ID, ['input' => 'Hello'], 'test', null); |
|
| 1094 | + $this->manager->scheduleTask($task); |
|
| 1095 | + |
|
| 1096 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 1097 | + |
|
| 1098 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 1099 | + Server::get(ITimeFactory::class), |
|
| 1100 | + $this->manager, |
|
| 1101 | + $this->jobList, |
|
| 1102 | + Server::get(LoggerInterface::class), |
|
| 1103 | + ); |
|
| 1104 | + $backgroundJob->start($this->jobList); |
|
| 1105 | + |
|
| 1106 | + $task = $this->manager->getTask($task->getId()); |
|
| 1107 | + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 1108 | + self::assertTrue($task->getOutput() === null); |
|
| 1109 | + self::assertEquals('ERROR', $task->getErrorMessage()); |
|
| 1110 | + self::assertTrue($this->providers[FailingTextProcessingSummaryProvider::class]->ran); |
|
| 1111 | + } |
|
| 1112 | + |
|
| 1113 | + public function testShouldTransparentlyHandleText2ImageProviders(): void { |
|
| 1114 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
| 1115 | + new ServiceRegistration('test', SuccessfulTextToImageProvider::class) |
|
| 1116 | + ]); |
|
| 1117 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1118 | + ]); |
|
| 1119 | + $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1120 | + self::assertCount(1, $taskTypes); |
|
| 1121 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1122 | + self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
| 1123 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1124 | + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
| 1125 | + $this->manager->scheduleTask($task); |
|
| 1126 | + |
|
| 1127 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskSuccessfulEvent::class)); |
|
| 1128 | + |
|
| 1129 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 1130 | + Server::get(ITimeFactory::class), |
|
| 1131 | + $this->manager, |
|
| 1132 | + $this->jobList, |
|
| 1133 | + Server::get(LoggerInterface::class), |
|
| 1134 | + ); |
|
| 1135 | + $backgroundJob->start($this->jobList); |
|
| 1136 | + |
|
| 1137 | + $task = $this->manager->getTask($task->getId()); |
|
| 1138 | + self::assertEquals(Task::STATUS_SUCCESSFUL, $task->getStatus()); |
|
| 1139 | + self::assertIsArray($task->getOutput()); |
|
| 1140 | + self::assertTrue(isset($task->getOutput()['images'])); |
|
| 1141 | + self::assertIsArray($task->getOutput()['images']); |
|
| 1142 | + self::assertCount(3, $task->getOutput()['images']); |
|
| 1143 | + self::assertTrue($this->providers[SuccessfulTextToImageProvider::class]->ran); |
|
| 1144 | + $node = $this->rootFolder->getFirstNodeByIdInPath($task->getOutput()['images'][0], '/' . $this->rootFolder->getAppDataDirectoryName() . '/'); |
|
| 1145 | + self::assertNotNull($node); |
|
| 1146 | + self::assertInstanceOf(File::class, $node); |
|
| 1147 | + self::assertEquals('test', $node->getContent()); |
|
| 1148 | + } |
|
| 1149 | + |
|
| 1150 | + public function testShouldTransparentlyHandleFailingText2ImageProviders(): void { |
|
| 1151 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([ |
|
| 1152 | + new ServiceRegistration('test', FailingTextToImageProvider::class) |
|
| 1153 | + ]); |
|
| 1154 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1155 | + ]); |
|
| 1156 | + $taskTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1157 | + self::assertCount(1, $taskTypes); |
|
| 1158 | + self::assertCount(1, $this->manager->getAvailableTaskTypeIds()); |
|
| 1159 | + self::assertTrue(isset($taskTypes[TextToImage::ID])); |
|
| 1160 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1161 | + $task = new Task(TextToImage::ID, ['input' => 'Hello', 'numberOfImages' => 3], 'test', null); |
|
| 1162 | + $this->manager->scheduleTask($task); |
|
| 1163 | + |
|
| 1164 | + $this->eventDispatcher->expects($this->once())->method('dispatchTyped')->with(new IsInstanceOf(TaskFailedEvent::class)); |
|
| 1165 | + |
|
| 1166 | + $backgroundJob = new SynchronousBackgroundJob( |
|
| 1167 | + Server::get(ITimeFactory::class), |
|
| 1168 | + $this->manager, |
|
| 1169 | + $this->jobList, |
|
| 1170 | + Server::get(LoggerInterface::class), |
|
| 1171 | + ); |
|
| 1172 | + $backgroundJob->start($this->jobList); |
|
| 1173 | + |
|
| 1174 | + $task = $this->manager->getTask($task->getId()); |
|
| 1175 | + self::assertEquals(Task::STATUS_FAILED, $task->getStatus()); |
|
| 1176 | + self::assertTrue($task->getOutput() === null); |
|
| 1177 | + self::assertEquals('ERROR', $task->getErrorMessage()); |
|
| 1178 | + self::assertTrue($this->providers[FailingTextToImageProvider::class]->ran); |
|
| 1179 | + } |
|
| 1180 | + |
|
| 1181 | + public function testMergeProvidersLocalAndEvent() { |
|
| 1182 | + // Arrange: Local provider registered, DIFFERENT external provider via event |
|
| 1183 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1184 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1185 | + ]); |
|
| 1186 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1187 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1188 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1189 | + |
|
| 1190 | + $externalProvider = new ExternalProvider(); // ID = 'event:external:provider' |
|
| 1191 | + $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1192 | + $this->manager = $this->createManagerInstance(); |
|
| 1193 | + |
|
| 1194 | + // Act |
|
| 1195 | + $providers = $this->manager->getProviders(); |
|
| 1196 | + |
|
| 1197 | + // Assert: Both providers should be present |
|
| 1198 | + self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers); |
|
| 1199 | + self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]); |
|
| 1200 | + self::assertArrayHasKey(ExternalProvider::ID, $providers); |
|
| 1201 | + self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]); |
|
| 1202 | + self::assertCount(2, $providers); |
|
| 1203 | + } |
|
| 1204 | + |
|
| 1205 | + public function testGetProvidersIncludesExternalViaEvent() { |
|
| 1206 | + // Arrange: No local providers, one external provider via event |
|
| 1207 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1208 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1209 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1210 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1211 | + |
|
| 1212 | + |
|
| 1213 | + $externalProvider = new ExternalProvider(); |
|
| 1214 | + $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1215 | + $this->manager = $this->createManagerInstance(); // Create manager with configured mocks |
|
| 1216 | + |
|
| 1217 | + // Act |
|
| 1218 | + $providers = $this->manager->getProviders(); // Returns ID-indexed array |
|
| 1219 | + |
|
| 1220 | + // Assert |
|
| 1221 | + self::assertArrayHasKey(ExternalProvider::ID, $providers); |
|
| 1222 | + self::assertInstanceOf(ExternalProvider::class, $providers[ExternalProvider::ID]); |
|
| 1223 | + self::assertCount(1, $providers); |
|
| 1224 | + self::assertTrue($this->manager->hasProviders()); |
|
| 1225 | + } |
|
| 1226 | + |
|
| 1227 | + public function testGetAvailableTaskTypesIncludesExternalViaEvent() { |
|
| 1228 | + // Arrange: No local types/providers, one external type and provider via event |
|
| 1229 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1230 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([]); |
|
| 1231 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1232 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1233 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1234 | + |
|
| 1235 | + $externalProvider = new ExternalProvider(); // Provides ExternalTaskType |
|
| 1236 | + $externalTaskType = new ExternalTaskType(); |
|
| 1237 | + $this->configureEventDispatcherMock( |
|
| 1238 | + providersToAdd: [$externalProvider], |
|
| 1239 | + taskTypesToAdd: [$externalTaskType] |
|
| 1240 | + ); |
|
| 1241 | + $this->manager = $this->createManagerInstance(); |
|
| 1242 | + |
|
| 1243 | + // Act |
|
| 1244 | + $availableTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1245 | + |
|
| 1246 | + // Assert |
|
| 1247 | + self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); |
|
| 1248 | + self::assertContains(ExternalTaskType::ID, $this->manager->getAvailableTaskTypeIds()); |
|
| 1249 | + self::assertEquals(ExternalTaskType::ID, $externalProvider->getTaskTypeId(), 'Test Sanity: Provider must handle the Task Type'); |
|
| 1250 | + self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']); |
|
| 1251 | + // Check if shapes match the external type/provider |
|
| 1252 | + self::assertArrayHasKey('external_input', $availableTypes[ExternalTaskType::ID]['inputShape']); |
|
| 1253 | + self::assertArrayHasKey('external_output', $availableTypes[ExternalTaskType::ID]['outputShape']); |
|
| 1254 | + self::assertEmpty($availableTypes[ExternalTaskType::ID]['optionalInputShape']); // From ExternalProvider |
|
| 1255 | + } |
|
| 1256 | + |
|
| 1257 | + public function testLocalProviderWinsConflictWithEvent() { |
|
| 1258 | + // Arrange: Local provider registered, conflicting external provider via event |
|
| 1259 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1260 | + new ServiceRegistration('test', SuccessfulSyncProvider::class) |
|
| 1261 | + ]); |
|
| 1262 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1263 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1264 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1265 | + |
|
| 1266 | + $conflictingExternalProvider = new ConflictingExternalProvider(); // ID = 'test:sync:success' |
|
| 1267 | + $this->configureEventDispatcherMock(providersToAdd: [$conflictingExternalProvider]); |
|
| 1268 | + $this->manager = $this->createManagerInstance(); |
|
| 1269 | + |
|
| 1270 | + // Act |
|
| 1271 | + $providers = $this->manager->getProviders(); |
|
| 1272 | + |
|
| 1273 | + // Assert: Only the local provider should be present for the conflicting ID |
|
| 1274 | + self::assertArrayHasKey(SuccessfulSyncProvider::ID, $providers); |
|
| 1275 | + self::assertInstanceOf(SuccessfulSyncProvider::class, $providers[SuccessfulSyncProvider::ID]); |
|
| 1276 | + self::assertCount(1, $providers); // Ensure no extra provider was added |
|
| 1277 | + } |
|
| 1278 | + |
|
| 1279 | + public function testTriggerableProviderWithNoOtherRunningTasks() { |
|
| 1280 | + // Arrange: Local provider registered, conflicting external provider via event |
|
| 1281 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1282 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1283 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1284 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1285 | + |
|
| 1286 | + $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']); |
|
| 1287 | + $externalProvider->expects($this->once())->method('trigger'); |
|
| 1288 | + $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1289 | + $this->manager = $this->createManagerInstance(); |
|
| 1290 | + |
|
| 1291 | + // Act |
|
| 1292 | + $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1293 | + $this->manager->scheduleTask($task); |
|
| 1294 | + } |
|
| 1295 | + |
|
| 1296 | + public function testTriggerableProviderWithOtherRunningTasks() { |
|
| 1297 | + // Arrange: Local provider registered, conflicting external provider via event |
|
| 1298 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([]); |
|
| 1299 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1300 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1301 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1302 | + |
|
| 1303 | + $externalProvider = $this->createPartialMock(ExternalTriggerableProvider::class, ['trigger']); |
|
| 1304 | + $externalProvider->expects($this->once())->method('trigger'); |
|
| 1305 | + $this->configureEventDispatcherMock(providersToAdd: [$externalProvider]); |
|
| 1306 | + $this->manager = $this->createManagerInstance(); |
|
| 1307 | + |
|
| 1308 | + $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1309 | + $this->manager->scheduleTask($task); |
|
| 1310 | + $this->manager->lockTask($task); |
|
| 1311 | + |
|
| 1312 | + // Act |
|
| 1313 | + $task = new Task($externalProvider->getTaskTypeId(), ['input' => ''], 'tests', 'foobar'); |
|
| 1314 | + $this->manager->scheduleTask($task); |
|
| 1315 | + } |
|
| 1316 | + |
|
| 1317 | + public function testMergeTaskTypesLocalAndEvent() { |
|
| 1318 | + // Arrange: Local type registered, DIFFERENT external type via event |
|
| 1319 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingProviders')->willReturn([ |
|
| 1320 | + new ServiceRegistration('test', AsyncProvider::class) |
|
| 1321 | + ]); |
|
| 1322 | + $this->registrationContext->expects($this->any())->method('getTaskProcessingTaskTypes')->willReturn([ |
|
| 1323 | + new ServiceRegistration('test', AudioToImage::class) |
|
| 1324 | + ]); |
|
| 1325 | + $this->registrationContext->expects($this->any())->method('getTextProcessingProviders')->willReturn([]); |
|
| 1326 | + $this->registrationContext->expects($this->any())->method('getTextToImageProviders')->willReturn([]); |
|
| 1327 | + $this->registrationContext->expects($this->any())->method('getSpeechToTextProviders')->willReturn([]); |
|
| 1328 | + |
|
| 1329 | + $externalTaskType = new ExternalTaskType(); // ID = 'event:external:tasktype' |
|
| 1330 | + $externalProvider = new ExternalProvider(); // Handles 'event:external:tasktype' |
|
| 1331 | + $this->configureEventDispatcherMock( |
|
| 1332 | + providersToAdd: [$externalProvider], |
|
| 1333 | + taskTypesToAdd: [$externalTaskType] |
|
| 1334 | + ); |
|
| 1335 | + $this->manager = $this->createManagerInstance(); |
|
| 1336 | + |
|
| 1337 | + // Act |
|
| 1338 | + $availableTypes = $this->manager->getAvailableTaskTypes(); |
|
| 1339 | + $availableTypeIds = $this->manager->getAvailableTaskTypeIds(); |
|
| 1340 | + |
|
| 1341 | + // Assert: Both task types should be available |
|
| 1342 | + self::assertContains(AudioToImage::ID, $availableTypeIds); |
|
| 1343 | + self::assertArrayHasKey(AudioToImage::ID, $availableTypes); |
|
| 1344 | + self::assertEquals(AudioToImage::class, $availableTypes[AudioToImage::ID]['name']); |
|
| 1345 | + |
|
| 1346 | + self::assertContains(ExternalTaskType::ID, $availableTypeIds); |
|
| 1347 | + self::assertArrayHasKey(ExternalTaskType::ID, $availableTypes); |
|
| 1348 | + self::assertEquals('External Task Type via Event', $availableTypes[ExternalTaskType::ID]['name']); |
|
| 1349 | + |
|
| 1350 | + self::assertCount(2, $availableTypes); |
|
| 1351 | + } |
|
| 1352 | + |
|
| 1353 | + private function createManagerInstance(): Manager { |
|
| 1354 | + // Clear potentially cached config values if needed |
|
| 1355 | + $this->appConfig->deleteKey('core', 'ai.taskprocessing_type_preferences'); |
|
| 1356 | + |
|
| 1357 | + // Re-create Text2ImageManager if its state matters or mocks change |
|
| 1358 | + $text2imageManager = new \OC\TextToImage\Manager( |
|
| 1359 | + $this->serverContainer, |
|
| 1360 | + $this->coordinator, |
|
| 1361 | + Server::get(LoggerInterface::class), |
|
| 1362 | + $this->jobList, |
|
| 1363 | + Server::get(\OC\TextToImage\Db\TaskMapper::class), |
|
| 1364 | + $this->config, // Use the shared config mock |
|
| 1365 | + Server::get(IAppDataFactory::class), |
|
| 1366 | + ); |
|
| 1367 | + |
|
| 1368 | + return new Manager( |
|
| 1369 | + $this->appConfig, |
|
| 1370 | + $this->coordinator, |
|
| 1371 | + $this->serverContainer, |
|
| 1372 | + Server::get(LoggerInterface::class), |
|
| 1373 | + $this->taskMapper, |
|
| 1374 | + $this->jobList, |
|
| 1375 | + $this->eventDispatcher, // Use the potentially reconfigured mock |
|
| 1376 | + Server::get(IAppDataFactory::class), |
|
| 1377 | + $this->rootFolder, |
|
| 1378 | + $text2imageManager, |
|
| 1379 | + $this->userMountCache, |
|
| 1380 | + Server::get(IClientService::class), |
|
| 1381 | + Server::get(IAppManager::class), |
|
| 1382 | + Server::get(IUserManager::class), |
|
| 1383 | + Server::get(IUserSession::class), |
|
| 1384 | + Server::get(ICacheFactory::class), |
|
| 1385 | + Server::get(IFactory::class), |
|
| 1386 | + ); |
|
| 1387 | + } |
|
| 1388 | + |
|
| 1389 | + private function configureEventDispatcherMock( |
|
| 1390 | + array $providersToAdd = [], |
|
| 1391 | + array $taskTypesToAdd = [], |
|
| 1392 | + ?int $expectedCalls = null, |
|
| 1393 | + ): void { |
|
| 1394 | + $dispatchExpectation = $expectedCalls === null ? $this->any() : $this->exactly($expectedCalls); |
|
| 1395 | + |
|
| 1396 | + $this->eventDispatcher->expects($dispatchExpectation) |
|
| 1397 | + ->method('dispatchTyped') |
|
| 1398 | + ->willReturnCallback(function (object $event) use ($providersToAdd, $taskTypesToAdd): void { |
|
| 1399 | + if ($event instanceof GetTaskProcessingProvidersEvent) { |
|
| 1400 | + foreach ($providersToAdd as $providerInstance) { |
|
| 1401 | + $event->addProvider($providerInstance); |
|
| 1402 | + } |
|
| 1403 | + foreach ($taskTypesToAdd as $taskTypeInstance) { |
|
| 1404 | + $event->addTaskType($taskTypeInstance); |
|
| 1405 | + } |
|
| 1406 | + } |
|
| 1407 | + }); |
|
| 1408 | + } |
|
| 1409 | 1409 | } |